From ea8e47ea24fb942155f16da6720b32de976d5731 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 29 Jan 2021 11:59:21 -0600 Subject: [PATCH] [dask] Add type hints in Dask package (#3866) * add type hints in dask module * starting on asserts * remove unused code * add hints for dtypes * replace accidentally-removed docstrings * revert unrelated change * Update python-package/lightgbm/dask.py * empty commit * fix hints on group * capitalize array * hide hints in signatures * empty commit * sphinx version * Apply suggestions from code review Co-authored-by: Nikita Titov * fix hint for MatrixLike * Update python-package/lightgbm/dask.py Co-authored-by: Nikita Titov * Apply suggestions from code review Co-authored-by: Nikita Titov * update docstring * empty commit Co-authored-by: Nikita Titov --- docs/conf.py | 5 +- python-package/lightgbm/compat.py | 12 ++- python-package/lightgbm/dask.py | 138 ++++++++++++++++++++++-------- 3 files changed, 117 insertions(+), 38 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index a66be6df8ccf..4a239c308621 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -74,7 +74,7 @@ def run(self): RTD = bool(os.environ.get('READTHEDOCS', '')) # If your documentation needs a minimal Sphinx version, state it here. -needs_sphinx = '1.3' # Due to sphinx.ext.napoleon +needs_sphinx = '2.1.0' # Due to sphinx.ext.napoleon, autodoc_typehints if needs_sphinx > sphinx.__version__: message = 'This project needs at least Sphinx v%s' % needs_sphinx raise VersionRequirementError(message) @@ -97,6 +97,9 @@ def run(self): "show-inheritance": True, } +# hide type hints in API docs +autodoc_typehints = "none" + # Generate autosummary pages. Output should be set with: `:toctree: pythonapi/` autosummary_generate = ['Python-API.rst'] diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index e800bb7b4795..350afcd0834f 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -113,7 +113,8 @@ def _check_sample_weight(sample_weight, X, dtype=None): try: from dask import delayed from dask.array import Array as dask_Array - from dask.dataframe import _Frame as dask_Frame + from dask.dataframe import DataFrame as dask_DataFrame + from dask.dataframe import Series as dask_Series from dask.distributed import Client, default_client, get_worker, wait DASK_INSTALLED = True except ImportError: @@ -129,7 +130,12 @@ class dask_Array: pass - class dask_Frame: - """Dummy class for dask.dataframe._Frame.""" + class dask_DataFrame: + """Dummy class for dask.dataframe.DataFrame.""" + + pass + + class dask_Series: + """Dummy class for dask.dataframe.Series.""" pass diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 218d8fa1cd89..c5f4049b0d5f 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -9,7 +9,7 @@ import socket from collections import defaultdict from copy import deepcopy -from typing import Dict, Iterable +from typing import Any, Dict, Iterable, List, Optional, Type, Union from urllib.parse import urlparse import numpy as np @@ -18,8 +18,13 @@ from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat, SKLEARN_INSTALLED, - DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait) -from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker + DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait) +from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker + +_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series] +_DaskMatrixLike = Union[dask_Array, dask_DataFrame] +_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix] +_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]] def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int: @@ -102,7 +107,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc return worker_ip_to_port -def _concat(seq): +def _concat(seq: List[_DaskPart]) -> _DaskPart: if isinstance(seq[0], np.ndarray): return np.concatenate(seq, axis=0) elif isinstance(seq[0], (pd_DataFrame, pd_Series)): @@ -113,8 +118,15 @@ def _concat(seq): raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0]))) -def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model, - time_out=120, **kwargs): +def _train_part( + params: Dict[str, Any], + model_factory: Type[LGBMModel], + list_of_parts: List[Dict[str, _DaskPart]], + worker_address_to_port: Dict[str, int], + return_model: bool, + time_out: int = 120, + **kwargs: Any +) -> Optional[LGBMModel]: local_worker_address = get_worker().address machine_list = ','.join([ '%s:%d' % (urlparse(worker_address).hostname, port) @@ -158,7 +170,7 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re return model if return_model else None -def _split_to_parts(data, is_matrix): +def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]: parts = data.to_delayed() if isinstance(parts, np.ndarray): if is_matrix: @@ -169,24 +181,33 @@ def _split_to_parts(data, is_matrix): return parts -def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs): +def _train( + client: Client, + data: _DaskMatrixLike, + label: _DaskCollection, + params: Dict[str, Any], + model_factory: Type[LGBMModel], + sample_weight: Optional[_DaskCollection] = None, + group: Optional[_DaskCollection] = None, + **kwargs: Any +) -> LGBMModel: """Inner train routine. Parameters ---------- client : dask.distributed.Client Dask client. - data : dask array of shape = [n_samples, n_features] + data : dask Array or dask DataFrame of shape = [n_samples, n_features] Input feature matrix. - label : dask array of shape = [n_samples] + label : dask Array, dask DataFrame or dask Series of shape = [n_samples] The target values (class labels in classification, real numbers in regression). params : dict Parameters passed to constructor of the local underlying model. model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class Class of the local underlying model. - sample_weight : array-like of shape = [n_samples] or None, optional (default=None) + sample_weight : dask Array, dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None) Weights of training data. - group : array-like or None, optional (default=None) + group : dask Array, dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None) Group/query data. Only used in the learning-to-rank task. sum(group) = n_samples. @@ -301,7 +322,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group return results[0] -def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs): +def _predict_part( + part: _DaskPart, + model: LGBMModel, + raw_score: bool, + pred_proba: bool, + pred_leaf: bool, + pred_contrib: bool, + **kwargs: Any +) -> _DaskPart: data = part.values if isinstance(part, pd_DataFrame) else part if data.shape[0] == 0: @@ -332,15 +361,23 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, * return result -def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False, - dtype=np.float32, **kwargs): +def _predict( + model: LGBMModel, + data: _DaskMatrixLike, + raw_score: bool = False, + pred_proba: bool = False, + pred_leaf: bool = False, + pred_contrib: bool = False, + dtype: _PredictionDtype = np.float32, + **kwargs: Any +) -> dask_Array: """Inner predict routine. Parameters ---------- model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class Fitted underlying model. - data : dask array of shape = [n_samples, n_features] + data : dask Array or dask DataFrame of shape = [n_samples, n_features] Input feature matrix. raw_score : bool, optional (default=False) Whether to predict raw scores. @@ -357,16 +394,16 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr Returns ------- - predicted_result : dask array of shape = [n_samples] or shape = [n_samples, n_classes] + predicted_result : dask Array of shape = [n_samples] or shape = [n_samples, n_classes] The predicted values. - X_leaves : dask array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes] + X_leaves : dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes] If ``pred_leaf=True``, the predicted leaf of every tree for each sample. - X_SHAP_values : dask array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects + X_SHAP_values : dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] If ``pred_contrib=True``, the feature contributions for each sample. """ if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') - if isinstance(data, dask_Frame): + if isinstance(data, dask_DataFrame): return data.map_partitions( _predict_part, model=model, @@ -392,11 +429,21 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr **kwargs ) else: - raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data))) + raise TypeError('Data must be either dask Array or dask DataFrame. Got %s.' % str(type(data))) class _DaskLGBMModel: - def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs): + + def _fit( + self, + model_factory: Type[LGBMModel], + X: _DaskMatrixLike, + y: _DaskCollection, + sample_weight: Optional[_DaskCollection] = None, + group: Optional[_DaskCollection] = None, + client: Optional[Client] = None, + **kwargs: Any + ) -> "_DaskLGBMModel": if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)): raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask') if client is None: @@ -420,13 +467,13 @@ def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, return self - def _to_local(self, model_factory): + def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel: model = model_factory(**self.get_params()) self._copy_extra_params(self, model) return model @staticmethod - def _copy_extra_params(source, dest): + def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None: params = source.get_params() attributes = source.__dict__ extra_param_names = set(attributes.keys()).difference(params.keys()) @@ -437,7 +484,14 @@ def _copy_extra_params(source, dest): class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel): """Distributed version of lightgbm.LGBMClassifier.""" - def fit(self, X, y, sample_weight=None, client=None, **kwargs): + def fit( + self, + X: _DaskMatrixLike, + y: _DaskCollection, + sample_weight: Optional[_DaskCollection] = None, + client: Optional[Client] = None, + **kwargs: Any + ) -> "DaskLGBMClassifier": """Docstring is inherited from the lightgbm.LGBMClassifier.fit.""" return self._fit( model_factory=LGBMClassifier, @@ -455,7 +509,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs): + ' ' * 12 + 'Dask client.\n' + ' ' * 8 + _init_score + _after_init_score) - def predict(self, X, **kwargs): + def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict.""" return _predict( model=self.to_local(), @@ -466,7 +520,7 @@ def predict(self, X, **kwargs): predict.__doc__ = LGBMClassifier.predict.__doc__ - def predict_proba(self, X, **kwargs): + def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba.""" return _predict( model=self.to_local(), @@ -477,7 +531,7 @@ def predict_proba(self, X, **kwargs): predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__ - def to_local(self): + def to_local(self) -> LGBMClassifier: """Create regular version of lightgbm.LGBMClassifier from the distributed version. Returns @@ -491,7 +545,14 @@ def to_local(self): class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRegressor.""" - def fit(self, X, y, sample_weight=None, client=None, **kwargs): + def fit( + self, + X: _DaskMatrixLike, + y: _DaskCollection, + sample_weight: Optional[_DaskCollection] = None, + client: Optional[Client] = None, + **kwargs: Any + ) -> "DaskLGBMRegressor": """Docstring is inherited from the lightgbm.LGBMRegressor.fit.""" return self._fit( model_factory=LGBMRegressor, @@ -509,7 +570,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs): + ' ' * 12 + 'Dask client.\n' + ' ' * 8 + _init_score + _after_init_score) - def predict(self, X, **kwargs): + def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRegressor.predict.""" return _predict( model=self.to_local(), @@ -519,7 +580,7 @@ def predict(self, X, **kwargs): predict.__doc__ = LGBMRegressor.predict.__doc__ - def to_local(self): + def to_local(self) -> LGBMRegressor: """Create regular version of lightgbm.LGBMRegressor from the distributed version. Returns @@ -533,7 +594,16 @@ def to_local(self): class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel): """Distributed version of lightgbm.LGBMRanker.""" - def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs): + def fit( + self, + X: _DaskMatrixLike, + y: _DaskCollection, + sample_weight: Optional[_DaskCollection] = None, + init_score: Optional[_DaskCollection] = None, + group: Optional[_DaskCollection] = None, + client: Optional[Client] = None, + **kwargs: Any + ) -> "DaskLGBMRanker": """Docstring is inherited from the lightgbm.LGBMRanker.fit.""" if init_score is not None: raise RuntimeError('init_score is not currently supported in lightgbm.dask') @@ -555,13 +625,13 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None + ' ' * 12 + 'Dask client.\n' + ' ' * 8 + _eval_set + _after_eval_set) - def predict(self, X, **kwargs): + def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array: """Docstring is inherited from the lightgbm.LGBMRanker.predict.""" return _predict(self.to_local(), X, **kwargs) predict.__doc__ = LGBMRanker.predict.__doc__ - def to_local(self): + def to_local(self) -> LGBMRanker: """Create regular version of lightgbm.LGBMRanker from the distributed version. Returns