diff --git a/.gitignore b/.gitignore
index 81aad72480..acd56e6a01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -29,6 +29,9 @@ log
 dask-worker-space/
 tmp/
 
+## files pickled in notebook when run during python docstring generation
+docs/source/*.model
+
 ## eclipse
 .project
 .cproject
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 7200fc7381..01bf52ed4a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,7 @@
 - PR #2594: Confidence intervals for ARIMA forecasts
 - PR #2607: Add support for probability estimates in SVC
 - PR #2618: SVM class and sample weights
+- PR #2635: Decorator to generate docstrings with autodetection of parameters
 - PR #2270: Multi class MNMG RF
 - PR #2661: CUDA-11 support for single-gpu code
 - PR #2322: Sparse FIL forests with 8-byte nodes
diff --git a/docs/source/api.rst b/docs/source/api.rst
index f2fe9ecd92..03ad255aa1 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -201,6 +201,12 @@ Mini Batch SGD Regressor
 .. autoclass:: cuml.MBSGDRegressor
     :members:
 
+Multinomial Naive Bayes
+-----------------------
+
+.. autoclass:: cuml.MultinomialNB
+    :members:
+
 Stochastic Gradient Descent
 ---------------------------
 
diff --git a/python/cuml/__init__.py b/python/cuml/__init__.py
index 9a065a8829..c5a0a0167c 100644
--- a/python/cuml/__init__.py
+++ b/python/cuml/__init__.py
@@ -50,6 +50,8 @@
 from cuml.metrics.cluster.adjustedrandindex import adjusted_rand_score
 from cuml.metrics.regression import r2_score
 
+from cuml.naive_bayes.naive_bayes import MultinomialNB
+
 from cuml.neighbors.nearest_neighbors import NearestNeighbors
 
 from cuml.preprocessing.LabelEncoder import LabelEncoder
@@ -84,7 +86,6 @@
 
 from cuml.common.memory_utils import set_global_output_type, using_output_type
 
-
 # Version configuration
 from ._version import get_versions
diff --git a/python/cuml/cluster/dbscan.pyx b/python/cuml/cluster/dbscan.pyx
index b4764f800d..ae4862308b 100644
--- a/python/cuml/cluster/dbscan.pyx
+++ b/python/cuml/cluster/dbscan.pyx
@@ -30,6 +30,7 @@ from libc.stdlib cimport calloc, malloc, free
 
 from cuml.common.array import CumlArray
 from cuml.common.base import Base
+from cuml.common.doc_utils import generate_docstring
 from cuml.common.handle cimport cumlHandle
 from cuml.common import input_to_cuml_array
 
@@ -204,19 +205,17 @@ class DBSCAN(Base):
         if self.max_mbytes_per_batch is None:
             self.max_mbytes_per_batch = 0
 
+    @generate_docstring(skip_parameters_heading=True)
     def fit(self, X, out_dtype="int32"):
         """
         Perform DBSCAN clustering from features.
 
         Parameters
         ----------
-        X : array-like (device or host) shape = (n_samples, n_features)
-            Dense matrix (floats or doubles) of shape (n_samples, n_features).
-            Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
-            ndarray, cuda array interface compliant array like CuPy
         out_dtype: dtype Determines the precision of the output labels array.
             default: "int32". Valid values are { "int32", np.int32,
-            "int64", np.int64}. When the number of samples exceed
+            "int64", np.int64}.
+
         """
         self._set_n_features_in(X)
         self._set_output_type(X)
@@ -321,21 +320,21 @@ class DBSCAN(Base):
 
         return self
 
+    @generate_docstring(skip_parameters_heading=True,
+                        return_values={'name': 'preds',
+                                       'type': 'dense',
+                                       'description': 'Cluster labels',
+                                       'shape': '(n_samples, 1)'})
     def fit_predict(self, X, out_dtype="int32"):
         """
-        Performs clustering on input_gdf and returns cluster labels.
+        Performs clustering on X and returns cluster labels.
Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features) - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ------- - y : cuDF Series, shape (n_samples) - cluster labels + out_dtype: dtype Determines the precision of the output labels array. + default: "int32". Valid values are { "int32", np.int32, + "int64", np.int64}. + """ self.fit(X, out_dtype) return self.labels_ diff --git a/python/cuml/cluster/kmeans.pyx b/python/cuml/cluster/kmeans.pyx index 39ee32b365..469915bfb4 100644 --- a/python/cuml/cluster/kmeans.pyx +++ b/python/cuml/cluster/kmeans.pyx @@ -31,6 +31,7 @@ from libc.stdlib cimport calloc, malloc, free from cuml.common.array import CumlArray from cuml.common.base import Base +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array from cuml.cluster.kmeans_utils cimport * @@ -140,7 +141,7 @@ class KMeans(Base): print(b) print("Calling fit") - kmeans_float = KMeans(n_clusters=2, n_gpu=-1) + kmeans_float = KMeans(n_clusters=2) kmeans_float.fit(b) print("labels:") @@ -306,21 +307,11 @@ class KMeans(Base): params.n_init = self.n_init self._params = params + @generate_docstring() def fit(self, X, sample_weight=None): """ Compute k-means clustering with X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - sample_weight : array-like (device or host) shape = (n_samples,), default=None # noqa - The weights for each observation in X. If None, all observations - are assigned equal weight. - """ self._set_n_features_in(X) self._set_output_type(X) @@ -407,21 +398,14 @@ class KMeans(Base): del(sample_weight_m) return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Cluster indexes', + 'shape': '(n_samples, 1)'}) def fit_predict(self, X, sample_weight=None): """ Compute cluster centers and predict cluster index for each sample. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - sample_weight : array-like (device or host) shape = (n_samples,), default=None # noqa - The weights for each observation in X. If None, all observations - are assigned equal weight. - """ return self.fit(X, sample_weight=sample_weight).labels_ @@ -522,21 +506,14 @@ class KMeans(Base): del(sample_weight_m) return self._labels_.to_output(out_type), inertia + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Cluster indexes', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False, sample_weight=None): """ Predict the closest cluster each sample in X belongs to. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ------- - labels : array - Which cluster each datapoint belongs to. """ labels, _ = self._predict_labels_inertia(X, @@ -544,21 +521,14 @@ class KMeans(Base): sample_weight=sample_weight) return labels + @generate_docstring(return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Transformed data', + 'shape': '(n_samples, n_clusters)'}) def transform(self, X, convert_dtype=False): """ Transform X to a cluster-distance space. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the transform method will, when necessary, - convert the input to the data type which was used to train the - model. This will increase memory used for the method. """ out_type = self._get_output_type(X) @@ -614,54 +584,29 @@ class KMeans(Base): del(X_m) return preds.to_output(out_type) + @generate_docstring(return_values={'name': 'score', + 'type': 'float', + 'description': 'Opposite of the value \ + of X on the K-means \ + objective.'}) def score(self, X, y=None, sample_weight=None, convert_dtype=True): """ Opposite of the value of X on the K-means objective. - Parameters - ---------- - X : array-like (device or host) shape (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : Ignored - Not used, present here for API consistency by convention. - sample_weight : array-like (device or host) of shape (n_samples,), - default=None. Acceptable formats: cuDF DataFrame, NumPy ndarray, - Numba device ndarray, cuda array interface compliant array like - CuPy. - convert_dtype : bool, optional (default = False) - When set to True, the transform method will, when necessary, - convert the input to the data type which was used to train the - model. This will increase memory used for the method. - - - Returns - ------- - score: float - Opposite of the value of X on the K-means objective. """ return -1 * self._predict_labels_inertia( X, convert_dtype=convert_dtype, sample_weight=sample_weight)[1] + @generate_docstring(return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Transformed data', + 'shape': '(n_samples, n_clusters)'}) def fit_transform(self, X, convert_dtype=False): """ Compute clustering and transform X to cluster-distance space. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the fit_transform method will automatically - convert the input to the data type which was used to train the - model. This will increase memory used for the method. 
- """ return self.fit(X).transform(X, convert_dtype=convert_dtype) diff --git a/python/cuml/common/base.pyx b/python/cuml/common/base.pyx index aa91bf5f6c..1920ea55f8 100644 --- a/python/cuml/common/base.pyx +++ b/python/cuml/common/base.pyx @@ -29,6 +29,7 @@ import inspect from cudf.core import Series as cuSeries from cudf.core import DataFrame as cuDataFrame from cuml.common.array import CumlArray +from cuml.common.doc_utils import generate_docstring from cupy import ndarray as cupyArray from numba.cuda import devicearray as numbaArray from numpy import ndarray as numpyArray @@ -336,26 +337,16 @@ class RegressorMixin: _estimator_type = "regressor" + @generate_docstring(return_values={'name': 'score', + 'type': 'float', + 'description': 'R^2 of self.predict(X) ' + 'wrt. y.'}) def score(self, X, y, **kwargs): - """Scoring function for regression estimators + """ + Scoring function for regression estimators Returns the coefficient of determination R^2 of the prediction. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Test samples on which we predict - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : array-like (device or host) shape = (n_samples, n_features) - Ground truth values for predict(X) - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ------- - score : float - R^2 of self.predict(X) wrt. y. """ from cuml.metrics.regression import r2_score @@ -373,21 +364,16 @@ class ClassifierMixin: _estimator_type = "classifier" + @generate_docstring(return_values={'name': 'score', + 'type': 'float', + 'description': 'Accuracy of \ + self.predict(X) wrt. y \ + (fraction where y == \ + pred_y)'}) def score(self, X, y, **kwargs): """ Scoring function for classifier estimators based on mean accuracy. - Parameters - ---------- - X : [cudf.DataFrame] - Test samples on which we predict - y : [cudf.Series, device array, or numpy array] - Ground truth values for predict(X) - - Returns - ------- - score : float - Accuracy of self.predict(X) wrt. y (fraction where y == pred_y) """ from cuml.metrics.accuracy import accuracy_score from cuml.common import input_to_dev_array diff --git a/python/cuml/common/doc_utils.py b/python/cuml/common/doc_utils.py new file mode 100644 index 0000000000..bd2e8dfcbd --- /dev/null +++ b/python/cuml/common/doc_utils.py @@ -0,0 +1,430 @@ +# +# Copyright (c) 2019-2020, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +""" +Decorators to generate common docstrings in the codebase. +Dense datatypes are currently the default, if you're a developer that landed +here, the docstrings apply to every parameter to which the decorators +are applied. The docstrings are generated at import time. + +There are 2 decorators: +- generate_docstring: Meant to be used by fit/predict/et.al methods that have + the typical signatures (i.e. fit(x,y) or predict(x)). 
+  It detects the parameters and default values and generates the
+  appropriate docstring, with some configurability for shapes and formats.
+- insert_into_docstring: More flexible but less automatic method, meant to be
+  used by functions that use our common dense or sparse datatypes, but have
+  many more custom parameters that are particular to the class(es) as opposed
+  to being common in the codebase. It allows us to keep our documentation up
+  to date and correct with minimal changes by keeping our common datatypes
+  concentrated here. NearestNeighbors is a good example of this use case.
+
+More data types can be added as we need them.
+
+cuml.dask datatype versions of the docstrings will come in a future update.
+
+"""
+
+from inspect import signature
+
+
+_parameters_docstrings = {
+    'dense':
+    '{name} : array-like (device or host) shape = {shape}\n'
+    '    Dense matrix containing floats or doubles.\n'
+    '    Acceptable formats: CUDA array interface compliant objects like\n'
+    '    CuPy, cuDF DataFrame/Series, NumPy ndarray and Pandas\n'
+    '    DataFrame/Series.',
+
+    'dense_anydtype':
+    '{name} : array-like (device or host) shape = {shape}\n'
+    '    Dense matrix of any dtype.\n'
+    '    Acceptable formats: CUDA array interface compliant objects like\n'
+    '    CuPy, cuDF DataFrame/Series, NumPy ndarray and Pandas\n'
+    '    DataFrame/Series.',
+
+    'dense_intdtype':
+    '{name} : array-like (device or host) shape = {shape}\n'
+    '    Dense matrix of type np.int32.\n'
+    '    Acceptable formats: CUDA array interface compliant objects like\n'
+    '    CuPy, cuDF DataFrame/Series, NumPy ndarray and Pandas\n'
+    '    DataFrame/Series.',
+
+    'sparse':
+    '{name} : sparse array-like (device) shape = {shape}\n'
+    '    Sparse matrix containing floats or doubles.\n'
+    '    Acceptable formats: cupy.sparse',
+
+    'dense_sparse':
+    '{name} : array-like (device or host) shape = {shape}\n'
+    '    Dense or sparse matrix containing floats or doubles.\n'
+    '    Acceptable dense formats: CUDA array interface compliant objects like\n'  # noqa
+    '    CuPy, cuDF DataFrame/Series, NumPy ndarray and Pandas\n'
+    '    DataFrame/Series.',
+
+    'convert_dtype_fit':
+    'convert_dtype : bool, optional (default = {default})\n'
+    '    When set to True, the train method will, when necessary, convert\n'
+    '    y to be the same data type as X if they differ. This\n'
+    '    will increase memory used for the method.',
+
+    'convert_dtype_other':
+    'convert_dtype : bool, optional (default = {default})\n'
+    '    When set to True, the {func_name} method will, when necessary,\n'
+    '    convert the input to the data type which was used to train the\n'
+    '    model. This will increase memory used for the method.',
+
+    'convert_dtype_single':
+    'convert_dtype : bool, optional (default = {default})\n'
+    '    When set to True, the method will automatically\n'
+    '    convert the inputs to {dtype}.',
+
+    'sample_weight':
+    'sample_weight : array-like (device or host) shape = (n_samples,), default={default}\n'  # noqa
+    '    The weights for each observation in X. If None, all observations\n'
+    '    are assigned equal weight.\n'
+    '    Acceptable dense formats: CUDA array interface compliant objects like\n'  # noqa
+    '    CuPy, cuDF DataFrame/Series, NumPy ndarray and Pandas\n'
+    '    DataFrame/Series.',  # noqa
+    'return_sparse':
+    'return_sparse : bool, optional (default = {default})\n'
+    '    Ignored when the model is not fit on a sparse matrix\n'
+    '    If True, the method will convert the result to a\n'
+    '    cupy.sparse.csr_matrix object.\n'
+    '    NOTE: Currently, there is a loss of information when converting\n'
+    '    to csr matrix (cusolver bug). Default will be switched to True\n'
+    '    once this is solved.',
+
+    'sparse_tol':
+    'sparse_tol : float, optional (default = {default})\n'
+    '    Ignored when return_sparse=False.\n'
+    '    If True, values in the inverse transform below this parameter\n'
+    '    are clipped to 0.'
+}
+
+_parameter_possible_values = ['name',
+                              'type',
+                              'shape',
+                              'default',
+                              'description',
+                              'accepted']
+
+_return_values_docstrings = {
+    'dense':
+    '{name} : cuDF, CuPy or NumPy object depending on cuML\'s output type configuration, shape = {shape}\n'  # noqa
+    '    {description}\n\n    For more information on how to configure cuML\'s output type,\n'  # noqa
+    '    refer to: `Output Data Type Configuration`_.',  # noqa
+
+    'dense_sparse':
+    '{name} : cuDF, CuPy or NumPy object depending on cuML\'s output type configuration, cupy.sparse for sparse output, shape = {shape}\n'  # noqa
+    '    {description}\n\n    For more information on how to configure cuML\'s dense output type,\n'  # noqa
+    '    refer to: `Output Data Type Configuration`_.',  # noqa
+
+    'dense_datatype':
+    'cuDF, CuPy or NumPy object depending on cuML\'s output type '
+    'configuration, shape = {shape}',
+
+    'dense_sparse_datatype':
+    'cuDF, CuPy or NumPy object depending on cuML\'s output type '
+    'configuration, shape = {shape}',
+
+    'custom_type':
+    '{name} : {type}\n'
+    '    {description}'
+}
+
+_return_values_possible_values = ['name',
+                                  'type',
+                                  'shape',
+                                  'description']
+
+_simple_params = ['return_sparse',
+                  'sparse_tol',
+                  'sample_weight']
+
+
+def generate_docstring(X='dense',
+                       X_shape='(n_samples, n_features)',
+                       y='dense',
+                       y_shape='(n_samples, 1)',
+                       convert_dtype_cast=False,
+                       skip_parameters=[],
+                       skip_parameters_heading=False,
+                       prepend_parameters=True,
+                       parameters=False,
+                       return_values=False):
+    """
+    Decorator to generate docstrings of common functions in the codebase.
+    It will auto detect what parameters and default values the function has.
+    Unfortunately due to using cython, we cannot (cheaply) do detection of
+    return values.
+
+    Currently auto-detected variables include:
+
+    - X
+    - y
+    - convert_dtype
+    - sample_weight
+    - return_sparse
+    - sparse_tol
+
+    Typical usage scenarios:
+
+    Examples
+    --------
+
+    # for a function that passes all dense parameters, no need to specify
+    # anything, and the decorator auto detects the parameters and defaults
+
+    @generate_docstring()
+    def fit(self, X, y, convert_dtype=True):
+
+    # for a function that takes X as dense or sparse
+
+    @generate_docstring(X='dense_sparse')
+    def fit(self, X, y, sample_weight=None):
+
+    # to specify return values
+
+    @generate_docstring(return_values={'name': 'preds',
+                                       'type': 'dense',
+                                       'description': 'Predicted values',
+                                       'shape': '(n_samples, 1)'})
+
+
+    Parameters
+    -----------
+    X : str (default = 'dense')
+        Data type of variable X. Currently accepted types are: dense,
+        dense_anydtype, dense_intdtype, sparse, dense_sparse
+    X_shape : str (default = '(n_samples, n_features)')
+        Shape of variable X
+    y : str (default = 'dense')
+        Data type of variable y. Currently accepted types are: dense,
+        dense_anydtype, dense_intdtype, sparse, dense_sparse
+    y_shape : str (default = '(n_samples, 1)')
+        Shape of variable y
+    convert_dtype_cast : Boolean or str (default = False)
+        If not false, use it to specify when convert_dtype is used to convert
+        to a single specific dtype (as opposed to converting the dtype of one
+        variable to the dtype of another, for example). An example of this is
+        how NearestNeighbors and UMAP use convert_dtype to convert inputs to
+        np.float32.
+ skip_parameters : list of str (default = []) + Use if you want the decorator to skip generating a docstring entry + for a specific parameter + skip_parameters_heading : boolean (default = False) + Set to True to not generate the Parameters section heading + prepend_parameters : boolean (default = True) + Use when setting skip_parameters_heading to True, so that the + parameters inserted by the decorator are inserted before the + parameters you already have in your docstring. + return_values : dict or list of dicts (default = False) + Use to generate docstrings of return values. One dictionary per + return value, this is the format: + {'name': 'name_of_variable', + 'type': 'data type of returned value', + 'description': 'Description of variable', + 'shape': 'shape of returned variable'} + + If type is one of dense or dense_sparse then the type is generated + from the corresponding entry in _return_values_docstrings. Otherwise + the type is used as specified. + """ + + def deco(func): + params = signature(func).parameters + + # Add parameter section header if needed, can be skipped + if(('X' in params or 'y' in params or parameters) and not + skip_parameters_heading): + + func.__doc__ += \ + '\nParameters\n----------\n' + + # Check if we want to prepend the parameters + if skip_parameters_heading and prepend_parameters: + loc_pars = func.__doc__.find("----------") + 11 + current_params_in_docstring = \ + func.__doc__[loc_pars:] + + func.__doc__ = func.__doc__[:loc_pars] + + # Process each parameter + for par, value in params.items(): + if par == 'self': + pass + + # X and y are the most common + elif par in ['X', 'y'] and par not in skip_parameters: + func.__doc__ += \ + _parameters_docstrings[X].format(name=par, + shape=X_shape) + + # convert_dtype requires some magic to distinguish + # whether we use the fit version or the version + # for the other methods. + elif par == 'convert_dtype' and par not in skip_parameters: + if not convert_dtype_cast: + if func.__name__ == 'fit': + k = 'convert_dtype_fit' + else: + k = 'convert_dtype_other' + + func.__doc__ += \ + _parameters_docstrings[k].format( + default=params['convert_dtype'].default, + func_name=func.__name__ + ) + + else: + func.__doc__ += \ + _parameters_docstrings['convert_dtype_single'].format( + default=params['convert_dtype'].default, + dtype=convert_dtype_cast + ) + + # All other parameters only take a default (for now). + else: + if par in _simple_params: + func.__doc__ += \ + _parameters_docstrings[par].format( + default=params[par].default + ) + func.__doc__ += '\n\n' + + if skip_parameters_heading and prepend_parameters: + # indexing at 8 to match indentation of inserted parameters + # this can be replaced with indentation detection + # https://github.com/rapidsai/cuml/issues/2714 + func.__doc__ += current_params_in_docstring[8:] + + # Add return section header if needed, no option to skip currently. 
+ if(return_values): + func.__doc__ += \ + '\nReturns\n----------\n' + + # convenience call to allow users to pass a single return + # value as a dictionary instead of a list of dictionaries + rets = [return_values] if not isinstance(return_values, list) \ + else return_values + + # process each entry in the return_values + # auto naming of predicted variable names will be a + # future improvement + for ret in rets: + if ret['type'] in _return_values_docstrings: + key = ret['type'] + # non custom types don't take the type parameter + del ret['type'] + else: + key = 'custom_type' + + # ret is already a dictionary, we just use it for the named + # parameters + func.__doc__ += \ + _return_values_docstrings[key].format( + **ret + ) + func.__doc__ += '\n\n' + + return func + return deco + + +def insert_into_docstring(parameters=False, + return_values=False): + """ + Decorator to insert a single entry into an existing docstring. Use + standard {} format parameters in your docstring, and then use this + decorator to insert the standard type information for that variable. + + Examples + -------- + + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], + return_values=[('dense', '(n_samples, n_features)'), + ('dense', + '(n_samples, n_features)')]) + def kneighbors(self, X=None, n_neighbors=None, return_distance=True, + convert_dtype=True): + \""" + Query the GPU index for the k nearest neighbors of column vectors in X. + + Parameters + ---------- + X : {} + + n_neighbors : Integer + Number of neighbors to search. If not provided, the n_neighbors + from the model instance is used (default=10) + + return_distance: Boolean + If False, distances will not be returned + + convert_dtype : bool, optional (default = True) + When set to True, the kneighbors method will automatically + convert the inputs to np.float32. + + Returns + ------- + distances : {} + The distances of the k-nearest neighbors for each column vector + in X + + indices : {} + The indices of the k-nearest neighbors for each column vector in X + \""" + + Parameters + ---------- + parameters : list of tuples + List of tuples, each tuple containing: (type, shape) for the type + and shape of each parameter to be inserted. Current accepted values + are `dense` and `dense_sparse`. + return_values : list of tuples + List of tuples, each tuple containing: (type, shape) for the type + and shape of each parameter to be inserted. Current accepted values + are `dense` and `dense_sparse`. 
+ + """ + + def deco(func): + # List of parameters to use in `format` call of the docstring + to_add = [] + + # See if we need to add parameter data types + if parameters: + for par in parameters: + to_add.append( + _parameters_docstrings[par[0]][9:].format(shape=par[1]) + ) + + # See if we need to add return value data types + if return_values: + for ret in return_values: + to_add.append( + _return_values_docstrings[ret[0] + '_datatype'].format( + shape=ret[1] + ) + ) + + if(len(to_add) > 0): + func.__doc__ = str(func.__doc__).format(*to_add) + + func.__doc__ += '\n\n' + + return func + return deco diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx index 8242f8490c..1dd5a29baf 100644 --- a/python/cuml/decomposition/pca.pyx +++ b/python/cuml/decomposition/pca.pyx @@ -39,6 +39,7 @@ from cython.operator cimport dereference as deref from cuml.common.array import CumlArray from cuml.common.base import Base from cuml.common.base import _input_to_type +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.decomposition.utils cimport * from cuml.common import input_to_cuml_array @@ -412,25 +413,10 @@ class PCA(Base): return self + @generate_docstring(X='dense_sparse') def fit(self, X, y=None): """ - Fit the model with X. - - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - sparse array-like (device) shape = (n_samples, n_features) - Acceptable formats: cupy.sparse - - y : ignored - - Returns - ------- - cluster labels + Fit the model with X. y is currently ignored. """ self._set_n_features_in(X) @@ -500,23 +486,15 @@ class PCA(Base): return self + @generate_docstring(X='dense_sparse', + return_values={'name': 'trans', + 'type': 'dense_sparse', + 'description': 'Transformed values', + 'shape': '(n_samples, n_components)'}) def fit_transform(self, X, y=None): """ Fit the model with X and apply the dimensionality reduction on X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - training data (floats or doubles), where n_samples is the number of - samples, and n_features is the number of features. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : ignored - - Returns - ------- - X_new : cuDF DataFrame, shape (n_samples, n_components) """ return self.fit(X).transform(X) @@ -557,6 +535,11 @@ class PCA(Base): X_inv, _, _, _ = input_to_cuml_array(X_inv, order='K') return X_inv.to_output(out_type) + @generate_docstring(X='dense_sparse', + return_values={'name': 'X_inv', + 'type': 'dense_sparse', + 'description': 'Transformed values', + 'shape': '(n_samples, n_features)'}) @with_cupy_rmm def inverse_transform(self, X, convert_dtype=False, return_sparse=False, sparse_tol=1e-10): @@ -565,40 +548,6 @@ class PCA(Base): In other words, return an input X_original whose transform would be X. - Parameters - ---------- - X : dense array-like (device or host) shape = (n_samples, n_features) - New data (floats or doubles), where n_samples is the number of - samples and n_components is the number of components. 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - sparse array-like (device) shape = (n_samples, n_features) - Acceptable formats: cupy.sparse - - convert_dtype : bool, optional (default = False) - When set to True, the inverse_transform method will automatically - convert the input to the data type which was used to train the - model. This will increase memory used for the method. - - return_sparse : bool, optional (default = False) - Ignored when the model is not fit on a sparse matrix - If True, the method will convert the inverse transform to a - cupy.sparse.csr_matrix object - - NOTE: Currently, there is a loss of information when converting - to csr matrix (cusolver bug). Default can be switched to True - once this is solved - - sparse_tol : float, optional (default = 1e-10) - Ignored when return_sparse=False - If True, values in the inverse transform below this parameter - are clipped to 0 - - Returns - ------- - X_original : cuDF DataFrame, shape (n_samples, n_features) - """ out_type = self._get_output_type(X) @@ -700,6 +649,11 @@ class PCA(Base): input_to_cuml_array(X_transformed, order='K') return X_transformed.to_output(out_type) + @generate_docstring(X='dense_sparse', + return_values={'name': 'trans', + 'type': 'dense_sparse', + 'description': 'Transformed values', + 'shape': '(n_samples, n_components)'}) @with_cupy_rmm def transform(self, X, convert_dtype=False): """ @@ -708,27 +662,6 @@ class PCA(Base): X is projected on the first principal components previously extracted from a training set. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - New data (floats or doubles), where n_samples is the number of - samples and n_components is the number of components. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - sparse array-like (device) shape = (n_samples, n_features) - Acceptable formats: cupy.sparse - - convert_dtype : bool, optional (default = False) - When set to True, the transform method will automatically - convert the input to the data type which was used to train the - model. This will increase memory used for the method. - - - Returns - ------- - X_new : cuDF DataFrame, shape (n_samples, n_components) - """ out_type = self._get_output_type(X) diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx index 2da71861e3..536fef3935 100644 --- a/python/cuml/decomposition/tsvd.pyx +++ b/python/cuml/decomposition/tsvd.pyx @@ -33,6 +33,7 @@ import cuml from cuml.common.array import CumlArray from cuml.common.base import Base +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.decomposition.utils cimport * from cuml.common import input_to_cuml_array @@ -292,18 +293,10 @@ class TruncatedSVD(Base): dtype=self.dtype) self._noise_variance_ = CumlArray.zeros(1, dtype=self.dtype) + @generate_docstring() def fit(self, X, y=None): """ - Fit LSI model on training cudf DataFrame X. - - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : ignored + Fit LSI model on training cudf DataFrame X. y is currently ignored. 
""" @@ -311,23 +304,15 @@ class TruncatedSVD(Base): return self + @generate_docstring(return_values={'name': 'trans', + 'type': 'dense', + 'description': 'Reduced version of X', + 'shape': '(n_samples, n_components)'}) def fit_transform(self, X, y=None): """ Fit LSI model to X and perform dimensionality reduction on X. + y is currently ignored. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : ignored - - Returns - ------- - X_new : cuDF DataFrame, shape (n_samples, n_components) - Reduced version of X as a dense cuDF DataFrame """ self._set_output_type(X) self._set_n_features_in(X) @@ -386,26 +371,15 @@ class TruncatedSVD(Base): out_type = self._get_output_type(X) return _trans_input_.to_output(out_type) + @generate_docstring(return_values={'name': 'X_original', + 'type': 'dense', + 'description': 'X in original space', + 'shape': '(n_samples, n_features)'}) def inverse_transform(self, X, convert_dtype=False): """ Transform X back to its original space. - Returns a cuDF DataFrame X_original whose transform would be X. - - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - convert_dtype : bool, optional (default = False) - When set to True, the inverse_transform method will automatically - convert the input to the data type which was used to train the - model. This will increase memory used for the method. - - Returns - ------- - X_original : cuDF DataFrame, shape (n_samples, n_features) - Note that this is always a dense cuDF DataFrame. + Returns X_original whose transform would be X. + """ trans_input, n_rows, _, dtype = \ @@ -447,27 +421,14 @@ class TruncatedSVD(Base): out_type = self._get_output_type(X) return input_data.to_output(out_type) + @generate_docstring(return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Reduced version of X', + 'shape': '(n_samples, n_components)'}) def transform(self, X, convert_dtype=False): """ Perform dimensionality reduction on X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the transform method will automatically - convert the input to the data type which was used to train the - model. - - Returns - ------- - X_new : cuDF DataFrame, shape (n_samples, n_components) - Reduced version of X. This will always be a dense DataFrame. 
- """ input, n_rows, _, dtype = \ input_to_cuml_array(X, check_dtype=self.dtype, diff --git a/python/cuml/ensemble/randomforestclassifier.pyx b/python/cuml/ensemble/randomforestclassifier.pyx index fa537c4e0f..06add5b54f 100644 --- a/python/cuml/ensemble/randomforestclassifier.pyx +++ b/python/cuml/ensemble/randomforestclassifier.pyx @@ -29,6 +29,8 @@ import cuml.common.logger as logger from cuml import ForestInference from cuml.common.array import CumlArray from cuml.common.base import ClassifierMixin +from cuml.common.doc_utils import generate_docstring +from cuml.common.doc_utils import insert_into_docstring from cuml.common.handle import Handle from cuml.common import input_to_cuml_array, rmm_cupy_ary @@ -383,26 +385,15 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): algo=algo, fil_sparse_format=fil_sparse_format) - """ - TODO : Move functions duplicated in the RF classifier and regressor - to a shared file. Cuml issue #1854 has been created to track this. - """ - + @generate_docstring(skip_parameters_heading=True, + y='dense_intdtype', + convert_dtype_cast='np.float32') def fit(self, X, y, convert_dtype=True): """ Perform Random Forest Classification on the input data Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (int32) of shape (n_samples, 1). - Acceptable formats: NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - These labels should be contiguous integers from 0 to n_classes. convert_dtype : bool, optional (default = True) When set to True, the fit method will, when necessary, convert y to be of dtype int32. This will increase memory used for @@ -531,6 +522,8 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): del(X_m) return preds.to_output(output_type=out_type, output_dtype=out_dtype) + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], + return_values=[('dense', '(n_samples, 1)')]) def predict(self, X, predict_model="GPU", output_class=True, threshold=0.5, algo='auto', num_classes=None, @@ -541,10 +534,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy + X : {} predict_model : String (default = 'GPU') 'GPU' to predict using the GPU, 'CPU' otherwise. 
The 'GPU' can only be used if the model was trained on float32 data and `X` is float32 @@ -591,8 +581,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): Returns ---------- - y : (same as the input datatype) - Dense vector (ints, floats, or doubles) of shape (n_samples, 1) + y : {} """ if num_classes: warnings.warn("num_classes is deprecated and will be removed" @@ -683,6 +672,8 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): del(X_m) return preds.to_output(out_type) + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], + return_values=[('dense', '(n_samples, 1)')]) def predict_proba(self, X, output_class=True, threshold=0.5, algo='auto', num_classes=None, convert_dtype=True, @@ -694,10 +685,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy + X : {} output_class: boolean (default = True) This is optional and required only while performing the predict operation on the GPU. @@ -739,10 +727,7 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): Returns ------- - y : (same as the input datatype) - Dense vector (float) of shape (n_samples, 1). The datatype of y - depend on the value of 'output_type' varaible specified by the - user while intializing the model. + y : {} """ if self.dtype == np.float64: raise TypeError("GPU based predict only accepts np.float32 data. \ @@ -771,6 +756,8 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): return preds_proba + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)'), + ('dense_intdtype', '(n_samples, 1)')]) def score(self, X, y, threshold=0.5, algo='auto', num_classes=None, predict_model="GPU", convert_dtype=True, fil_sparse_format='auto'): @@ -779,12 +766,8 @@ class RandomForestClassifier(BaseRandomForestModel, ClassifierMixin): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : NumPy - Dense vector (int) of shape (n_samples, 1) + X : {} + y : {} algo : string (default = 'auto') This is optional and required only while performing the predict operation on the GPU. diff --git a/python/cuml/ensemble/randomforestregressor.pyx b/python/cuml/ensemble/randomforestregressor.pyx index 61b6ab3e0f..a4d606f090 100644 --- a/python/cuml/ensemble/randomforestregressor.pyx +++ b/python/cuml/ensemble/randomforestregressor.pyx @@ -29,6 +29,8 @@ from cuml import ForestInference from cuml.common.array import CumlArray from cuml.common.base import RegressorMixin +from cuml.common.doc_utils import generate_docstring +from cuml.common.doc_utils import insert_into_docstring from cuml.common.handle import Handle from cuml.common import input_to_cuml_array, rmm_cupy_ary @@ -365,30 +367,11 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): algo=algo, fil_sparse_format=fil_sparse_format) - """ - TODO : Move functions duplicated in the RF classifier and regressor - to a shared file. Cuml issue #1854 has been created to track this. 
- """ - + @generate_docstring() def fit(self, X, y, convert_dtype=True): """ Perform Random Forest Regression on the input data - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (int32) of shape (n_samples, 1). - Acceptable formats: NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - These labels should be contiguous integers from 0 to n_classes. - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This will increase - memory used for the method. """ X_m, y_m, max_feature_val = self._dataset_setup_for_fit(X, y, convert_dtype) @@ -503,6 +486,8 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): del(X_m) return preds.to_output(out_type) + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], + return_values=[('dense', '(n_samples, 1)')]) def predict(self, X, predict_model="GPU", algo='auto', convert_dtype=True, fil_sparse_format='auto'): @@ -511,10 +496,7 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy + X : {} predict_model : String (default = 'GPU') 'GPU' to predict using the GPU, 'CPU' otherwise. The GPU can only be used if the model was trained on float32 data and `X` is float32 @@ -546,8 +528,7 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): Returns ---------- - y : NumPy - Dense vector (int) of shape (n_samples, 1) + y : {} """ if predict_model == "CPU": @@ -570,6 +551,8 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): return preds + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)'), + ('dense', '(n_samples, 1)')]) def score(self, X, y, algo='auto', convert_dtype=True, fil_sparse_format='auto', predict_model="GPU"): """ @@ -577,12 +560,8 @@ class RandomForestRegressor(BaseRandomForestModel, RegressorMixin): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : NumPy - Dense vector (int) of shape (n_samples, 1) + X : {} + y : {} algo : string (default = 'auto') This is optional and required only while performing the predict operation on the GPU. 
diff --git a/python/cuml/linear_model/elastic_net.pyx b/python/cuml/linear_model/elastic_net.pyx index c4bea1d836..29bf32d1fb 100644 --- a/python/cuml/linear_model/elastic_net.pyx +++ b/python/cuml/linear_model/elastic_net.pyx @@ -21,6 +21,7 @@ from cuml.solvers import CD from cuml.common.base import Base, RegressorMixin +from cuml.common.doc_utils import generate_docstring class ElasticNet(Base, RegressorMixin): @@ -201,26 +202,11 @@ class ElasticNet(Base, RegressorMixin): msg = "l1_ratio value has to be between 0.0 and 1.0" raise ValueError(msg.format(l1_ratio)) + @generate_docstring() def fit(self, X, y, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the transform method will, when necessary, - convert y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self._set_output_type(X) @@ -229,26 +215,13 @@ class ElasticNet(Base, RegressorMixin): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ - Predicts the y for X. - - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. - - Returns - ---------- - y: cuDF DataFrame - Dense vector (floats or doubles) of shape (n_samples, 1) + Predicts `y` values for `X`. """ diff --git a/python/cuml/linear_model/lasso.pyx b/python/cuml/linear_model/lasso.pyx index f96c23c927..a14321acf3 100644 --- a/python/cuml/linear_model/lasso.pyx +++ b/python/cuml/linear_model/lasso.pyx @@ -21,6 +21,7 @@ from cuml.solvers import CD from cuml.common.base import Base, RegressorMixin +from cuml.common.doc_utils import generate_docstring class Lasso(Base, RegressorMixin): @@ -163,26 +164,11 @@ class Lasso(Base, RegressorMixin): msg = "alpha value has to be positive" raise ValueError(msg.format(alpha)) + @generate_docstring() def fit(self, X, y, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). 
- Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the transform method will, when necessary, - convert y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self._set_output_type(X) @@ -190,22 +176,14 @@ class Lasso(Base, RegressorMixin): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ---------- - y: cuDF DataFrame - Dense vector (floats or doubles) of shape (n_samples, 1) - """ return self.solver_model.predict(X, convert_dtype=convert_dtype) diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index ae73806858..5221aa3348 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -33,6 +33,7 @@ from libc.stdlib cimport calloc, malloc, free from cuml.common.array import CumlArray from cuml.common.base import Base, RegressorMixin +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array @@ -211,26 +212,11 @@ class LinearRegression(Base, RegressorMixin): 'eig': 1 }[algorithm] + @generate_docstring() def fit(self, X, y, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self._set_output_type(X) @@ -303,27 +289,14 @@ class LinearRegression(Base, RegressorMixin): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts `y` values for `X`. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. 
- - Returns - ------- - y: cuDF DataFrame - Dense vector (floats or doubles) of shape (n_samples, 1) - """ out_type = self._get_output_type(X) diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx index ebe33b7fb5..94b6de2503 100644 --- a/python/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/linear_model/logistic_regression.pyx @@ -25,6 +25,7 @@ import pprint from cuml.solvers import QN from cuml.common.base import Base, ClassifierMixin from cuml.common.array import CumlArray +from cuml.common.doc_utils import generate_docstring import cuml.common.logger as logger from cuml.common import input_to_cuml_array, with_cupy_rmm @@ -245,27 +246,12 @@ class LogisticRegression(Base, ClassifierMixin): else: self.verb_prefix = "" + @generate_docstring() @with_cupy_rmm def fit(self, X, y, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self.solver_model._set_target_dtype(y) self._set_output_type(X) @@ -312,76 +298,41 @@ class LogisticRegression(Base, ClassifierMixin): return self + @generate_docstring(return_values={'name': 'score', + 'type': 'dense', + 'description': 'Confidence score', + 'shape': '(n_samples, n_classes)'}) def decision_function(self, X, convert_dtype=False): """ Gives confidence score for X - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. - - Returns - ---------- - y: array-like (device) - Dense matrix (floats or doubles) of shape (n_samples, n_classes) """ return self.solver_model._decision_function( X, convert_dtype=convert_dtype ).to_output(output_type=self._get_output_type(X)) + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. 
-
-        Returns
-        ----------
-        y : (same as the input datatype)
-            Dense vector (ints, floats, or doubles) of shape (n_samples, 1).
         """
         return self.solver_model.predict(X, convert_dtype=convert_dtype)
 
+    @generate_docstring(return_values={'name': 'preds',
+                                       'type': 'dense',
+                                       'description': 'Predicted class \
+                                                       probabilities',
+                                       'shape': '(n_samples, n_classes)'})
     @with_cupy_rmm
     def predict_proba(self, X, convert_dtype=False):
         """
         Predicts the class probabilities for each class in X
 
-        Parameters
-        ----------
-        X : array-like (device or host) shape = (n_samples, n_features)
-            Dense matrix (floats or doubles) of shape (n_samples, n_features).
-            Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
-            ndarray, cuda array interface compliant array like CuPy
-
-        convert_dtype : bool, optional (default = False)
-            When set to True, the predict method will, when necessary, convert
-            the input to the data type which was used to train the model. This
-            will increase memory used for the method.
-
-        Returns
-        ----------
-        y: array-like (device)
-            Dense matrix (floats or doubles) of shape (n_samples, n_classes)
         """
         return self._predict_proba_impl(
             X,
             convert_dtype=convert_dtype,
             log_proba=False
         )
 
+    @generate_docstring(return_values={'name': 'preds',
+                                       'type': 'dense',
+                                       'description': 'Logarithm of predicted \
+                                                       class probabilities',
+                                       'shape': '(n_samples, n_classes)'})
     def predict_log_proba(self, X, convert_dtype=False):
         """
         Predicts the log class probabilities for each class in X
 
-        Parameters
-        ----------
-        X : array-like (device or host) shape = (n_samples, n_features)
-            Dense matrix (floats or doubles) of shape (n_samples, n_features).
-            Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
-            ndarray, cuda array interface compliant array like CuPy
-
-        convert_dtype : bool, optional (default = False)
-            When set to True, the predict method will, when necessary, convert
-            the input to the data type which was used to train the model. This
-            will increase memory used for the method.
-
-        Returns
-        ----------
-        y: array-like (device)
-            Dense matrix (floats or doubles) of shape (n_samples, n_classes)
         """
         return self._predict_proba_impl(
             X,
diff --git a/python/cuml/linear_model/mbsgd_classifier.pyx b/python/cuml/linear_model/mbsgd_classifier.pyx
index 5aefd638ec..9f0aa0394a 100644
--- a/python/cuml/linear_model/mbsgd_classifier.pyx
+++ b/python/cuml/linear_model/mbsgd_classifier.pyx
@@ -19,6 +19,7 @@
 # cython: embedsignature = True
 # cython: language_level = 3
 from cuml.common.base import Base, ClassifierMixin
+from cuml.common.doc_utils import generate_docstring
 from cuml.solvers import SGD
 
 
@@ -166,52 +167,25 @@ class MBSGDClassifier(Base, ClassifierMixin):
         self.n_iter_no_change = n_iter_no_change
         self.solver_model = SGD(**self.get_params())
 
+    @generate_docstring()
     def fit(self, X, y, convert_dtype=True):
         """
         Fit the model with X and y.
 
-        Parameters
-        ----------
-        X : array-like (device or host) shape = (n_samples, n_features)
-            Dense matrix (floats or doubles) of shape (n_samples, n_features).
-            Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device
-            ndarray, cuda array interface compliant array like CuPy
-
-        y : array-like (device or host) shape = (n_samples, 1)
-            Dense vector (floats or doubles) of shape (n_samples, 1).
- Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self.solver_model._estimator_type = self._estimator_type self.solver_model.fit(X, y, convert_dtype=convert_dtype) return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. - - Returns - ---------- - y : (same as the input datatype) - Dense vector (ints, floats, or doubles) of shape (n_samples, 1). """ preds = \ self.solver_model.predictClass(X, diff --git a/python/cuml/linear_model/mbsgd_regressor.pyx b/python/cuml/linear_model/mbsgd_regressor.pyx index f0ddb1d941..dc6ec5e66c 100644 --- a/python/cuml/linear_model/mbsgd_regressor.pyx +++ b/python/cuml/linear_model/mbsgd_regressor.pyx @@ -19,6 +19,7 @@ # cython: embedsignature = True # cython: language_level = 3 from cuml.common.base import Base, RegressorMixin +from cuml.common.doc_utils import generate_docstring from cuml.solvers import SGD @@ -162,51 +163,24 @@ class MBSGDRegressor(Base, RegressorMixin): self.n_iter_no_change = n_iter_no_change self.solver_model = SGD(**self.get_params()) + @generate_docstring() def fit(self, X, y, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self.solver_model.fit(X, y, convert_dtype=convert_dtype) return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. - - Returns - ---------- - y: Type specified by `output_type` - Dense vector (floats or doubles) of shape (n_samples, 1) """ preds = self.solver_model.predict(X, diff --git a/python/cuml/linear_model/ridge.pyx b/python/cuml/linear_model/ridge.pyx index bff3056f75..19bb53b8e2 100644 --- a/python/cuml/linear_model/ridge.pyx +++ b/python/cuml/linear_model/ridge.pyx @@ -32,6 +32,7 @@ from libc.stdlib cimport calloc, malloc, free from cuml.common.base import Base, RegressorMixin from cuml.common.array import CumlArray +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array @@ -242,26 +243,11 @@ class Ridge(Base, RegressorMixin): 'cd': 2 }[algorithm] + @generate_docstring() def fit(self, X, y, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_output_type(X) self._set_n_features_in(X) @@ -344,27 +330,14 @@ class Ridge(Base, RegressorMixin): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. 
- - Returns - ---------- - y: cuDF DataFrame - Dense vector (floats or doubles) of shape (n_samples, 1) - """ out_type = self._get_output_type(X) diff --git a/python/cuml/manifold/t_sne.pyx b/python/cuml/manifold/t_sne.pyx index 6a1c80296a..5dcb298c6b 100644 --- a/python/cuml/manifold/t_sne.pyx +++ b/python/cuml/manifold/t_sne.pyx @@ -33,6 +33,7 @@ from cuml.common.handle cimport cumlHandle import cuml.common.logger as logger from cuml.common.array import CumlArray +from cuml.common.doc_utils import generate_docstring from cuml.common import input_to_cuml_array import rmm @@ -298,18 +299,11 @@ class TSNE(Base): self.pre_learning_rate = learning_rate self.post_learning_rate = learning_rate * 2 + @generate_docstring(convert_dtype_cast='np.float32') def fit(self, X, convert_dtype=True): - """Fit X into an embedded space. - - Parameters - ----------- - X : array-like (device or host) shape = (n_samples, n_features) - X contains a sample per row. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. + """ + Fit X into an embedded space. + """ self._set_n_features_in(X) cdef int n, p @@ -407,23 +401,18 @@ class TSNE(Base): del self._embedding_ self._embedding_ = None + @generate_docstring(convert_dtype_cast='np.float32', + return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Embedding of the \ + training data in \ + low-dimensional space.', + 'shape': '(n_samples, n_components)'}) def fit_transform(self, X, convert_dtype=True): - """Fit X into an embedded space and return that transformed output. - - Parameters - ----------- - X : array-like (device or host) shape = (n_samples, n_features) - X contains a sample per row. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - convert_dtype : bool, optional (default = True) - When set to True, the fit_transform method will automatically - convert the inputs to np.float32. - - Returns - -------- - X_new : array, shape (n_samples, n_components) - Embedding of the training data in low-dimensional space. + """ + Fit X into an embedded space and return that transformed output. 
+ + """ self.fit(X, convert_dtype=convert_dtype) out_type = self._get_output_type(X) diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 7063b62daa..cb98cdff87 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -37,8 +37,10 @@ from cupy.sparse import csr_matrix as cp_csr_matrix,\ from cuml.common.base import Base from cuml.common.handle cimport cumlHandle -from cuml.common import get_cudf_column_ptr, get_dev_array_ptr, \ - input_to_cuml_array, zeros, with_cupy_rmm, has_scipy +from cuml.common.doc_utils import generate_docstring +from cuml.common.input_utils import input_to_cuml_array +from cuml.common.memory_utils import with_cupy_rmm +from cuml.common.import_utils import has_scipy from cuml.common.array import CumlArray import rmm @@ -487,6 +489,8 @@ class UMAP(Base): (knn_dists_m, knn_dists_m.ptr) return (None, None), (None, None) + @generate_docstring(convert_dtype_cast='np.float32', + skip_parameters_heading=True) @with_cupy_rmm def fit(self, X, y=None, convert_dtype=True, knn_graph=None): @@ -495,14 +499,6 @@ class UMAP(Base): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - X contains a sample per row. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - y : array-like (device or host) shape = (n_samples, 1) - y contains a label per row. - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy knn_graph : sparse array-like (device or host) shape=(n_samples, n_samples) A sparse array containing the k-nearest neighbors of X, @@ -602,6 +598,14 @@ class UMAP(Base): return self + @generate_docstring(convert_dtype_cast='np.float32', + skip_parameters_heading=True, + return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Embedding of the \ + data in \ + low-dimensional space.', + 'shape': '(n_samples, n_components)'}) def fit_transform(self, X, y=None, convert_dtype=True, knn_graph=None): """ @@ -617,10 +621,6 @@ class UMAP(Base): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - X contains a sample per row. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy knn_graph : sparse array-like (device or host) shape=(n_samples, n_samples) A sparse array containing the k-nearest neighbors of X, @@ -640,16 +640,20 @@ class UMAP(Base): Acceptable formats: sparse SciPy ndarray, CuPy device ndarray, CSR/COO preferred other formats will go through conversion to CSR - Returns - ------- - X_new : array, shape (n_samples, n_components) - Embedding of the training data in low-dimensional space. """ self.fit(X, y, convert_dtype=convert_dtype, knn_graph=knn_graph) out_type = self._get_output_type(X) return self._embedding_.to_output(out_type) + @generate_docstring(convert_dtype_cast='np.float32', + skip_parameters_heading=True, + return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Embedding of the \ + data in \ + low-dimensional space.', + 'shape': '(n_samples, n_components)'}) @with_cupy_rmm def transform(self, X, convert_dtype=True, knn_graph=None): @@ -666,10 +670,6 @@ class UMAP(Base): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - New data to be transformed. 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy knn_graph : sparse array-like (device or host) shape=(n_samples, n_samples) A sparse array containing the k-nearest neighbors of X, @@ -689,10 +689,6 @@ class UMAP(Base): Acceptable formats: sparse SciPy ndarray, CuPy device ndarray, CSR/COO preferred other formats will go through conversion to CSR - Returns - ------- - X_new : array, shape (n_samples, n_components) - Embedding of the new data in low-dimensional space. """ if len(X.shape) != 2: raise ValueError("data should be two dimensional") diff --git a/python/cuml/naive_bayes/naive_bayes.py b/python/cuml/naive_bayes/naive_bayes.py index ba82465762..e06271a822 100644 --- a/python/cuml/naive_bayes/naive_bayes.py +++ b/python/cuml/naive_bayes/naive_bayes.py @@ -23,6 +23,7 @@ from cuml.common import with_cupy_rmm from cuml.common import CumlArray from cuml.common.base import Base +from cuml.common.doc_utils import generate_docstring from cuml.common.input_utils import input_to_cuml_array from cuml.common.kernel_utils import cuda_kernel_factory from cuml.common.import_utils import has_scipy @@ -115,9 +116,6 @@ def count_features_dense_kernel(float_dtype, int_dtype): class MultinomialNB(Base): - # TODO: Make this extend cuml.Base: - # https://github.com/rapidsai/cuml/issues/1834 - """ Naive Bayes classifier for multinomial models @@ -216,21 +214,13 @@ def __init__(self, # Needed until Base no longer assumed cumlHandle self.handle = None + @generate_docstring(X='dense_sparse') @cp.prof.TimeRangeDecorator(message="fit()", color_id=0) @with_cupy_rmm def fit(self, X, y, sample_weight=None): """ Fit Naive Bayes classifier according to X, y - Parameters - ---------- - - X : {array-like, cupy sparse matrix} of shape (n_samples, n_features) - Training vectors, where n_samples is the number of samples and - n_features is the number of features. - y : array-like shape (n_samples) Target values. - sample_weight : array-like of shape (n_samples) - Weights applied to individial samples (1. for unweighted). """ self._set_n_features_in(X) return self.partial_fit(X, y, sample_weight) @@ -336,22 +326,17 @@ def partial_fit(self, X, y, classes=None, sample_weight=None): return self._partial_fit(X, y, sample_weight=sample_weight, _classes=classes) + @generate_docstring(X='dense_sparse', + return_values={'name': 'y_hat', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_rows, 1)'}) @cp.prof.TimeRangeDecorator(message="predict()", color_id=1) @with_cupy_rmm def predict(self, X): """ Perform classification on an array of test vectors X. - Parameters - ---------- - - X : array-like of shape (n_samples, n_features) - - Returns - ------- - - C : cupy.ndarray of shape (n_samples) - """ out_type = self._get_output_type(X) @@ -378,24 +363,18 @@ def predict(self, X): y_hat = invert_labels(indices, classes=self.classes_) return CumlArray(data=y_hat).to_output(out_type) + @generate_docstring(X='dense_sparse', + return_values={'name': 'C', + 'type': 'dense', + 'description': 'Returns the log-probability of the samples for each class in the \ + model. The columns correspond to the classes in sorted order, as \ + they appear in the attribute `classes_`.', # noqa + 'shape': '(n_rows, 1)'}) @with_cupy_rmm def predict_log_proba(self, X): """ Return log-probability estimates for the test vector X. 
- Parameters - ---------- - - X : array-like of shape (n_samples, n_features) - - - Returns - ------- - - C : array-like of shape (n_samples, n_classes) - Returns the log-probability of the samples for each class in the - model. The columns correspond to the classes in sorted order, as - they appear in the attribute classes_. """ out_type = self._get_output_type(X) @@ -437,28 +416,28 @@ def predict_log_proba(self, X): result = jll - log_prob_x.T return CumlArray(result).to_output(out_type) + @generate_docstring(X='dense_sparse', + return_values={'name': 'C', + 'type': 'dense', + 'description': 'Returns the probability of the samples for each class in the \ + model. The columns correspond to the classes in sorted order, as \ + they appear in the attribute `classes_`.', # noqa + 'shape': '(n_rows, 1)'}) @with_cupy_rmm def predict_proba(self, X): """ Return probability estimates for the test vector X. - Parameters - ---------- - - X : array-like of shape (n_samples, n_features) - - Returns - ------- - - C : array-like of shape (n_samples, n_classes) - Returns the probability of the samples for each class in the model. - The columns correspond to the classes in sorted order, as they - appear in the attribute classes_. """ out_type = self._get_output_type(X) result = cp.exp(self.predict_log_proba(X)) return CumlArray(result).to_output(out_type) + @generate_docstring(X='dense_sparse', + return_values={'name': 'score', + 'type': 'float', + 'description': 'Mean accuracy of \ + self.predict(X) with respect to y.'}) @with_cupy_rmm def score(self, X, y, sample_weight=None): """ @@ -468,21 +447,8 @@ def score(self, X, y, sample_weight=None): harsh metric since you require for each sample that each label set be correctly predicted. - Parameters - ---------- - X : array-like of shape (n_samples, n_features) - Test samples. - - y : array-like of shape (n_samples,) or (n_samples, n_outputs) - True labels for X. - - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. Currently, sample weight is ignored - - Returns - ------- + Currently, sample weight is ignored - score : float Mean accuracy of self.predict(X) with respect to y. """ y_hat = self.predict(X) return accuracy_score(y_hat, cp.asarray(y, dtype=y.dtype)) diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index 189b905348..ad2c877199 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -24,6 +24,7 @@ from cuml.neighbors.nearest_neighbors import NearestNeighbors from cuml.common.array import CumlArray from cuml.common import input_to_cuml_array from cuml.common.base import ClassifierMixin +from cuml.common.doc_utils import generate_docstring import numpy as np import cupy as cp @@ -145,26 +146,12 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): raise ValueError("Only uniform weighting strategy is " "supported currently.") + @generate_docstring(convert_dtype_cast='np.float32') @with_cupy_rmm def fit(self, X, y, convert_dtype=True): """ Fit a GPU index for k-nearest neighbors classifier model. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, n_outputs) - Dense matrix (floats or doubles) of shape (n_samples, n_outputs). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. """ self._set_target_dtype(y) @@ -177,25 +164,16 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): self._classes_ = CumlArray(cp.unique(self._y)) return self + @generate_docstring(convert_dtype_cast='np.float32', + return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Labels predicted', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=True): """ Use the trained k-nearest neighbors classifier to predict the labels for X - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. - - Returns - ---------- - y : (same as the input datatype) - Dense vector (ints, floats, or doubles) of shape (n_samples, 1). """ out_type = self._get_output_type(X) @@ -245,21 +223,17 @@ class KNeighborsClassifier(NearestNeighbors, ClassifierMixin): return classes.to_output(output_type=out_type, output_dtype=out_dtype) + @generate_docstring(convert_dtype_cast='np.float32', + return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Labels probabilities', + 'shape': '(n_samples, 1)'}) @with_cupy_rmm def predict_proba(self, X, convert_dtype=True): """ Use the trained k-nearest neighbors classifier to predict the label probabilities for X - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. """ out_type = self._get_output_type(X) diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx index 074b6cb5a2..ec619f3245 100644 --- a/python/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/neighbors/kneighbors_regressor.pyx @@ -24,6 +24,7 @@ from cuml.neighbors.nearest_neighbors import NearestNeighbors from cuml.common.array import CumlArray from cuml.common import input_to_cuml_array from cuml.common.base import RegressorMixin +from cuml.common.doc_utils import generate_docstring import numpy as np @@ -152,25 +153,11 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): raise ValueError("Only uniform weighting strategy " "is supported currently.") + @generate_docstring(convert_dtype_cast='np.float32') def fit(self, X, y, convert_dtype=True): """ Fit a GPU index for k-nearest neighbors regression model. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, n_outputs) - Dense matrix (floats or doubles) of shape (n_samples, n_outputs). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. """ super(KNeighborsRegressor, self).fit(X, convert_dtype=convert_dtype) self._y, _, _, _ = \ @@ -180,21 +167,16 @@ class KNeighborsRegressor(NearestNeighbors, RegressorMixin): else None)) return self + @generate_docstring(convert_dtype_cast='np.float32', + return_values={'name': 'X_new', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, n_features)'}) def predict(self, X, convert_dtype=True): """ Use the trained k-nearest neighbors regression model to predict the labels for X - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. """ out_type = self._get_output_type(X) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index 5ea0dbbb49..ba3a32e95d 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -20,7 +20,6 @@ # cython: language_level = 3 import numpy as np -import pandas as pd import cupy as cp import cudf import ctypes @@ -29,6 +28,8 @@ import warnings from cuml.common.base import Base from cuml.common.array import CumlArray +from cuml.common.doc_utils import generate_docstring +from cuml.common.doc_utils import insert_into_docstring from cuml.common import input_to_cuml_array from cython.operator cimport dereference as deref @@ -219,20 +220,11 @@ class NearestNeighbors(Base): self.p = p self.algorithm = algorithm + @generate_docstring() def fit(self, X, convert_dtype=True): """ Fit GPU index for performing nearest neighbor queries. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will automatically - convert the inputs to np.float32. """ self._set_n_features_in(X) self._set_output_type(X) @@ -289,6 +281,10 @@ class NearestNeighbors(Base): return m, expanded + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], + return_values=[('dense', '(n_samples, n_features)'), + ('dense', + '(n_samples, n_features)')]) def kneighbors(self, X=None, n_neighbors=None, return_distance=True, convert_dtype=True): """ @@ -296,10 +292,7 @@ class NearestNeighbors(Base): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy + X : {} n_neighbors : Integer Number of neighbors to search. If not provided, the n_neighbors @@ -314,11 +307,11 @@ class NearestNeighbors(Base): Returns ------- - distances: cuDF DataFrame, pandas DataFrame, numpy or cupy ndarray + distances : {} The distances of the k-nearest neighbors for each column vector in X - indices: cuDF DataFrame, pandas DataFrame, numpy or cupy ndarray + indices : {} The indices of the k-nearest neighbors for each column vector in X """ @@ -452,6 +445,7 @@ class NearestNeighbors(Base): return (D_output, I_output) if return_distance else I_output + @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')]) def kneighbors_graph(self, X=None, n_neighbors=None, mode='connectivity'): """ Find the k nearest neighbors of column vectors in X and return as @@ -459,10 +453,7 @@ class NearestNeighbors(Base): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy + X : {} n_neighbors : Integer Number of neighbors to search. If not provided, the n_neighbors diff --git a/python/cuml/solvers/cd.pyx b/python/cuml/solvers/cd.pyx index 8ce1695262..3fec2d1a42 100644 --- a/python/cuml/solvers/cd.pyx +++ b/python/cuml/solvers/cd.pyx @@ -30,6 +30,7 @@ from libc.stdlib cimport calloc, malloc, free from cuml.common.array import CumlArray from cuml.common.base import Base +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.common import get_cudf_column_ptr from cuml.common import get_dev_array_ptr @@ -208,26 +209,11 @@ class CD(Base): 'squared_loss': 0, }[loss] + @generate_docstring() def fit(self, X, y, convert_dtype=False): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_output_type(X) @@ -294,26 +280,14 @@ class CD(Base): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. 
This - will increase memory used for the method. - - Returns - ---------- - y: cuDF DataFrame - Dense vector (floats or doubles) of shape (n_samples, 1) """ out_type = self._get_output_type(X) diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index 04fa634db6..ff18679f82 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -25,7 +25,9 @@ import numpy as np from libcpp cimport bool from libc.stdint cimport uintptr_t -from cuml.common.base import Base, CumlArray +from cuml.common.array import CumlArray +from cuml.common.base import Base +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array from cuml.common import with_cupy_rmm @@ -263,27 +265,12 @@ class QN(Base): 'normal': 1 }[loss] + @generate_docstring() @with_cupy_rmm def fit(self, X, y, convert_dtype=False): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_output_type(X) @@ -450,26 +437,14 @@ class QN(Base): return scores + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. - - Returns - ---------- - y: cuDF DataFrame - Dense vector (floats or doubles) of shape (n_samples, 1) """ out_type = self._get_output_type(X) out_dtype = self._get_target_dtype() diff --git a/python/cuml/solvers/sgd.pyx b/python/cuml/solvers/sgd.pyx index 47e2fe02ad..110745d68f 100644 --- a/python/cuml/solvers/sgd.pyx +++ b/python/cuml/solvers/sgd.pyx @@ -30,7 +30,8 @@ from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free from cuml.common.base import Base -from cuml.common import CumlArray +from cuml.common.array import CumlArray +from cuml.common.doc_utils import generate_docstring from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array, with_cupy_rmm @@ -294,27 +295,12 @@ class SGD(Base): 'elasticnet': 3 }[penalty] + @generate_docstring() @with_cupy_rmm def fit(self, X, y, convert_dtype=False): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_output_type(X) self._set_target_dtype(y) @@ -399,26 +385,14 @@ class SGD(Base): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predict method will, when necessary, convert - the input to the data type which was used to train the model. This - will increase memory used for the method. - - Returns - ---------- - y: Type specified in `output_type` - Dense vector (floats or doubles) of shape (n_samples, 1) """ output_type = self._get_output_type(X) @@ -461,26 +435,14 @@ class SGD(Base): return preds.to_output(output_type) + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predictClass(self, X, convert_dtype=False): """ Predicts the y for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = False) - When set to True, the predictClass method will automatically - convert the input to the data type which was used to train the - model. This will increase memory used for the method. - - Returns - ---------- - y : Type specified in `output_type` - Dense vector (floats or doubles) of shape (n_samples, 1) """ output_type = self._get_output_type(X) out_dtype = self._get_target_dtype() diff --git a/python/cuml/svm/svc.pyx b/python/cuml/svm/svc.pyx index 3302170f9f..ac2b8959d5 100644 --- a/python/cuml/svm/svc.pyx +++ b/python/cuml/svm/svc.pyx @@ -30,6 +30,7 @@ from libc.stdint cimport uintptr_t from cuml.common.array import CumlArray from cuml.common.base import Base, ClassifierMixin +from cuml.common.doc_utils import generate_docstring from cuml.common.logger import warn from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array, input_to_host_array, with_cupy_rmm @@ -318,27 +319,12 @@ class SVC(SVMBase, ClassifierMixin): return sample_weight + @generate_docstring(y='dense_anydtype') @with_cupy_rmm def fit(self, X, y, sample_weight=None, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). 
- Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (any numeric type) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self._set_output_type(X) @@ -412,22 +398,14 @@ class SVC(SVMBase, ClassifierMixin): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X): """ Predicts the class labels for X. The returned y values are the class labels associated to sign(decision_function(X)). - - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ------- - y : (same as the input datatype) - Dense vector (ints, floats, or doubles) of shape (n_samples, 1). """ if self.probability: @@ -440,6 +418,12 @@ class SVC(SVMBase, ClassifierMixin): else: return super(SVC, self).predict(X, True) + @generate_docstring(skip_parameters_heading=True, + return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted \ + probabilities', + 'shape': '(n_samples, n_classes)'}) def predict_proba(self, X, log=False): """ Predicts the class probabilities for X. @@ -448,19 +432,9 @@ class SVC(SVMBase, ClassifierMixin): Parameters ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of input features. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - log: boolean (default = False) Whether to return log probabilities. - Returns - ------- - P : array-like (device or host) shape = (n_samples, n_classes) - Dense matrix of classs probabilities for each sample. - """ if self.probability: @@ -477,42 +451,29 @@ class SVC(SVMBase, ClassifierMixin): "probabilities. Fit a new classifier with" "probability=True to enable predict_proba.") + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Log of predicted \ + probabilities', + 'shape': '(n_samples, n_classes)'}) def predict_log_proba(self, X): """ Predicts the log probabilities for X (returns log(predict_proba(x)). The model has to be trained with probability=True to use this method. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of input features. - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy. - - Returns - ------- - P : array-like (device or host) shape = (n_samples, n_classes) - Dense matrix of log probabilities for each sample. - """ return self.predict_proba(X, log=True) + @generate_docstring(return_values={'name': 'results', + 'type': 'dense', + 'description': 'Decision function \ + values', + 'shape': '(n_samples, 1)'}) def decision_function(self, X): """ Calculates the decision function values for X. 
- Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ------- - y : cuDF Series - Dense vector (floats or doubles) of shape (n_samples, 1) """ if self.probability: self._check_is_fitted('prob_svc') diff --git a/python/cuml/svm/svr.pyx b/python/cuml/svm/svr.pyx index e3604d9cfa..c85f8b360a 100644 --- a/python/cuml/svm/svr.pyx +++ b/python/cuml/svm/svr.pyx @@ -29,7 +29,9 @@ from cython.operator cimport dereference as deref from libc.stdint cimport uintptr_t from cuml.common.array import CumlArray -from cuml.common.base import Base, RegressorMixin +from cuml.common.base import Base +from cuml.common.base import RegressorMixin +from cuml.common.doc_utils import generate_docstring from cuml.metrics import r2_score from cuml.common.handle cimport cumlHandle from cuml.common import input_to_cuml_array @@ -223,26 +225,11 @@ class SVR(SVMBase, RegressorMixin): verbose, epsilon) self.svmType = EPSILON_SVR + @generate_docstring() def fit(self, X, y, sample_weight=None, convert_dtype=True): """ Fit the model with X and y. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - y : array-like (device or host) shape = (n_samples, 1) - Dense vector (floats or doubles) of shape (n_samples, 1). - Acceptable formats: cuDF Series, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - convert_dtype : bool, optional (default = True) - When set to True, the fit method will, when necessary, convert - y to be the same data type as X if they differ. This - will increase memory used for the method. """ self._set_n_features_in(X) self._set_output_type(X) @@ -301,21 +288,14 @@ class SVR(SVMBase, RegressorMixin): return self + @generate_docstring(return_values={'name': 'preds', + 'type': 'dense', + 'description': 'Predicted values', + 'shape': '(n_samples, 1)'}) def predict(self, X): """ Predicts the values for X. - Parameters - ---------- - X : array-like (device or host) shape = (n_samples, n_features) - Dense matrix (floats or doubles) of shape (n_samples, n_features). - Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device - ndarray, cuda array interface compliant array like CuPy - - Returns - ------- - y : cuDF Series - Dense vector (floats or doubles) of shape (n_samples, 1) """ return super(SVR, self).predict(X, False) diff --git a/python/cuml/test/test_fit_function.py b/python/cuml/test/test_fit_function.py index 0c85dc89dd..6a0059dab5 100644 --- a/python/cuml/test/test_fit_function.py +++ b/python/cuml/test/test_fit_function.py @@ -36,7 +36,8 @@ def test_fit_function(dataset, model_name): "SparseRandomProjection", "TSNE", "TruncatedSVD", - "AutoARIMA" + "AutoARIMA", + "MultinomialNB" ]: pytest.xfail("These models are not tested yet") @@ -59,7 +60,7 @@ def test_fit_function(dataset, model_name): # and the inspect module doesn't work with Cython. Therefore we need # to register the number of arguments manually if `fit` is decorated pos_args_spec = { - "ARIMA": 1, + "ARIMA": 1 } n_pos_args_fit = ( pos_args_spec[model_name]
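Note on the decorator pattern used throughout this diff: every hand-written Parameters/Returns block above is replaced by @generate_docstring(...) from cuml.common.doc_utils, which autodetects the decorated method's parameters and fills in the standard cuML descriptions. The snippet below is a minimal, hypothetical sketch of how such a decorator can be built, not the actual cuML implementation; the _PARAM_DOCS lookup table, the ExampleEstimator class, and the restriction to skip_parameters_heading plus a dict-valued return_values are simplifying assumptions (the real helper also supports options such as convert_dtype_cast and the X='dense_sparse' / y='dense_anydtype' overrides seen in the hunks above).

# Sketch only: an illustration of a docstring-generating decorator, not the
# code added by this PR. Parameter descriptions are resolved by name from a
# small lookup table and appended to the method's existing docstring.
import inspect

# Hypothetical boilerplate keyed by parameter name.
_PARAM_DOCS = {
    'X': "X : array-like (device or host), shape = (n_samples, n_features)\n"
         "    Dense matrix. Acceptable formats: cuDF DataFrame, NumPy ndarray,\n"
         "    Numba device ndarray, or any __cuda_array_interface__ compliant array.",
    'y': "y : array-like (device or host), shape = (n_samples, 1)\n"
         "    Dense vector of labels or target values.",
    'convert_dtype': "convert_dtype : bool\n"
                     "    When True, convert inputs to the dtype expected by the model.",
    'sample_weight': "sample_weight : array-like (device or host), shape = (n_samples,), default=None\n"
                     "    Per-sample weights; None means all samples weigh equally.",
}


def generate_docstring(skip_parameters_heading=False, return_values=None, **_ignored):
    """Append auto-detected Parameters (and optional Returns) sections."""
    def decorator(func):
        lines = [inspect.cleandoc(func.__doc__ or ''), '']
        params = [p for p in inspect.signature(func).parameters if p != 'self']
        if params and not skip_parameters_heading:
            lines += ['Parameters', '----------']
        lines += [_PARAM_DOCS.get(p, p + ' : (no standard description)') for p in params]
        if return_values is not None:
            lines += ['', 'Returns', '-------',
                      "{} : array-like, shape = {}".format(return_values['name'],
                                                           return_values.get('shape', '')),
                      "    " + return_values['description']]
        # Rewrite __doc__ in place and hand back the original function.
        func.__doc__ = '\n'.join(lines)
        return func
    return decorator


class ExampleEstimator:
    # Mirrors how the diff decorates predict(); ExampleEstimator itself is made up.
    @generate_docstring(return_values={'name': 'preds',
                                       'type': 'dense',
                                       'description': 'Predicted values',
                                       'shape': '(n_samples, 1)'})
    def predict(self, X, convert_dtype=False):
        """
        Predicts the y for X.
        """
        return X


if __name__ == '__main__':
    # The generated text lands in __doc__, so Sphinx autodoc and help() pick it up.
    print(ExampleEstimator.predict.__doc__)

Because a decorator like this rewrites __doc__ in place and returns the original function, it composes cleanly with the other decorators appearing in the diff (@with_cupy_rmm, @cp.prof.TimeRangeDecorator), provided those wrappers preserve __doc__ (for example via functools.wraps).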
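The second helper imported in nearest_neighbors.pyx, insert_into_docstring, targets methods such as kneighbors whose docstrings keep their own hand-written layout but contain {} placeholders (X : {}, distances : {}, indices : {}). Again as a hedged sketch rather than the real implementation: the decorator can substitute a standard description for each placeholder, in order, via str.format. The _TYPE_DESCRIPTIONS table and the ExampleNN class below are invented for illustration.

# Sketch only: fills the `{}` placeholders left in a hand-written docstring
# with standard cuML-style type descriptions, in declaration order.
_TYPE_DESCRIPTIONS = {
    'dense': "array-like (device or host), shape = {shape}\n"
             "    Acceptable formats: cuDF DataFrame, NumPy ndarray, Numba device\n"
             "    ndarray, or any __cuda_array_interface__ compliant array.",
}


def insert_into_docstring(parameters=None, return_values=None):
    def decorator(func):
        fillers = [_TYPE_DESCRIPTIONS[kind].format(shape=shape)
                   for kind, shape in (parameters or []) + (return_values or [])]
        # Placeholders are filled positionally: parameters first, then returns.
        func.__doc__ = func.__doc__.format(*fillers)
        return func
    return decorator


class ExampleNN:
    @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')],
                           return_values=[('dense', '(n_samples, n_neighbors)')])
    def kneighbors(self, X=None, n_neighbors=None):
        """
        Query the k nearest neighbors of each row of X.

        Parameters
        ----------
        X : {}

        Returns
        -------
        distances : {}
            Distances to the nearest neighbors.
        """
        return X, n_neighbors

One caveat for real use of this approach: any literal braces in the docstring would have to be escaped ({{ }}) before str.format is safe, which is likely why the placeholder style is reserved for a couple of methods here while the rest rely on the fully generated sections.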