Skip to content

Commit

Permalink
[dask] Add type hints in Dask package (#3866)
Browse files Browse the repository at this point in the history
* add type hints in dask module

* starting on asserts

* remove unused code

* add hints for dtypes

* replace accidentally-removed docstrings

* revert unrelated change

* Update python-package/lightgbm/dask.py

* empty commit

* fix hints on group

* capitalize array

* hide hints in signatures

* empty commit

* sphinx version

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* fix hint for MatrixLike

* Update python-package/lightgbm/dask.py

Co-authored-by: Nikita Titov <[email protected]>

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* update docstring

* empty commit

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
jameslamb and StrikerRUS authored Jan 29, 2021
1 parent 217642c commit ea8e47e
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 38 deletions.
5 changes: 4 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run(self):
RTD = bool(os.environ.get('READTHEDOCS', ''))

# If your documentation needs a minimal Sphinx version, state it here.
needs_sphinx = '1.3' # Due to sphinx.ext.napoleon
needs_sphinx = '2.1.0' # Due to sphinx.ext.napoleon, autodoc_typehints
if needs_sphinx > sphinx.__version__:
message = 'This project needs at least Sphinx v%s' % needs_sphinx
raise VersionRequirementError(message)
Expand All @@ -97,6 +97,9 @@ def run(self):
"show-inheritance": True,
}

# hide type hints in API docs
autodoc_typehints = "none"

# Generate autosummary pages. Output should be set with: `:toctree: pythonapi/`
autosummary_generate = ['Python-API.rst']

Expand Down
12 changes: 9 additions & 3 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,8 @@ def _check_sample_weight(sample_weight, X, dtype=None):
try:
from dask import delayed
from dask.array import Array as dask_Array
from dask.dataframe import _Frame as dask_Frame
from dask.dataframe import DataFrame as dask_DataFrame
from dask.dataframe import Series as dask_Series
from dask.distributed import Client, default_client, get_worker, wait
DASK_INSTALLED = True
except ImportError:
Expand All @@ -129,7 +130,12 @@ class dask_Array:

pass

class dask_Frame:
"""Dummy class for dask.dataframe._Frame."""
class dask_DataFrame:
"""Dummy class for dask.dataframe.DataFrame."""

pass

class dask_Series:
"""Dummy class for dask.dataframe.Series."""

pass
138 changes: 104 additions & 34 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import socket
from collections import defaultdict
from copy import deepcopy
from typing import Dict, Iterable
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from urllib.parse import urlparse

import numpy as np
Expand All @@ -18,8 +18,13 @@
from .basic import _choose_param_value, _ConfigAliases, _LIB, _log_warning, _safe_call, LightGBMError
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
SKLEARN_INSTALLED,
DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
DASK_INSTALLED, dask_DataFrame, dask_Array, dask_Series, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]


def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
Expand Down Expand Up @@ -102,7 +107,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
return worker_ip_to_port


def _concat(seq):
def _concat(seq: List[_DaskPart]) -> _DaskPart:
if isinstance(seq[0], np.ndarray):
return np.concatenate(seq, axis=0)
elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
Expand All @@ -113,8 +118,15 @@ def _concat(seq):
raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))


def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
time_out=120, **kwargs):
def _train_part(
params: Dict[str, Any],
model_factory: Type[LGBMModel],
list_of_parts: List[Dict[str, _DaskPart]],
worker_address_to_port: Dict[str, int],
return_model: bool,
time_out: int = 120,
**kwargs: Any
) -> Optional[LGBMModel]:
local_worker_address = get_worker().address
machine_list = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
Expand Down Expand Up @@ -158,7 +170,7 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
return model if return_model else None


def _split_to_parts(data, is_matrix):
def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
parts = data.to_delayed()
if isinstance(parts, np.ndarray):
if is_matrix:
Expand All @@ -169,24 +181,33 @@ def _split_to_parts(data, is_matrix):
return parts


def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
def _train(
client: Client,
data: _DaskMatrixLike,
label: _DaskCollection,
params: Dict[str, Any],
model_factory: Type[LGBMModel],
sample_weight: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
**kwargs: Any
) -> LGBMModel:
"""Inner train routine.
Parameters
----------
client : dask.distributed.Client
Dask client.
data : dask array of shape = [n_samples, n_features]
data : dask Array or dask DataFrame of shape = [n_samples, n_features]
Input feature matrix.
label : dask array of shape = [n_samples]
label : dask Array, dask DataFrame or dask Series of shape = [n_samples]
The target values (class labels in classification, real numbers in regression).
params : dict
Parameters passed to constructor of the local underlying model.
model_factory : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Class of the local underlying model.
sample_weight : array-like of shape = [n_samples] or None, optional (default=None)
sample_weight : dask Array, dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
Weights of training data.
group : array-like or None, optional (default=None)
group : dask Array, dask DataFrame, Dask Series of shape = [n_samples] or None, optional (default=None)
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
Expand Down Expand Up @@ -301,7 +322,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
return results[0]


def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
def _predict_part(
part: _DaskPart,
model: LGBMModel,
raw_score: bool,
pred_proba: bool,
pred_leaf: bool,
pred_contrib: bool,
**kwargs: Any
) -> _DaskPart:
data = part.values if isinstance(part, pd_DataFrame) else part

if data.shape[0] == 0:
Expand Down Expand Up @@ -332,15 +361,23 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, *
return result


def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
dtype=np.float32, **kwargs):
def _predict(
model: LGBMModel,
data: _DaskMatrixLike,
raw_score: bool = False,
pred_proba: bool = False,
pred_leaf: bool = False,
pred_contrib: bool = False,
dtype: _PredictionDtype = np.float32,
**kwargs: Any
) -> dask_Array:
"""Inner predict routine.
Parameters
----------
model : lightgbm.LGBMClassifier, lightgbm.LGBMRegressor, or lightgbm.LGBMRanker class
Fitted underlying model.
data : dask array of shape = [n_samples, n_features]
data : dask Array or dask DataFrame of shape = [n_samples, n_features]
Input feature matrix.
raw_score : bool, optional (default=False)
Whether to predict raw scores.
Expand All @@ -357,16 +394,16 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
Returns
-------
predicted_result : dask array of shape = [n_samples] or shape = [n_samples, n_classes]
predicted_result : dask Array of shape = [n_samples] or shape = [n_samples, n_classes]
The predicted values.
X_leaves : dask array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
X_leaves : dask Array of shape = [n_samples, n_trees] or shape = [n_samples, n_trees * n_classes]
If ``pred_leaf=True``, the predicted leaf of every tree for each sample.
X_SHAP_values : dask array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes] or list with n_classes length of such objects
X_SHAP_values : dask Array of shape = [n_samples, n_features + 1] or shape = [n_samples, (n_features + 1) * n_classes]
If ``pred_contrib=True``, the feature contributions for each sample.
"""
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if isinstance(data, dask_Frame):
if isinstance(data, dask_DataFrame):
return data.map_partitions(
_predict_part,
model=model,
Expand All @@ -392,11 +429,21 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr
**kwargs
)
else:
raise TypeError('Data must be either Dask array or dataframe. Got %s.' % str(type(data)))
raise TypeError('Data must be either dask Array or dask DataFrame. Got %s.' % str(type(data)))


class _DaskLGBMModel:
def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs):

def _fit(
self,
model_factory: Type[LGBMModel],
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if client is None:
Expand All @@ -420,13 +467,13 @@ def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None,

return self

def _to_local(self, model_factory):
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
model = model_factory(**self.get_params())
self._copy_extra_params(self, model)
return model

@staticmethod
def _copy_extra_params(source, dest):
def _copy_extra_params(source: Union["_DaskLGBMModel", LGBMModel], dest: Union["_DaskLGBMModel", LGBMModel]) -> None:
params = source.get_params()
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
Expand All @@ -437,7 +484,14 @@ def _copy_extra_params(source, dest):
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""

def fit(self, X, y, sample_weight=None, client=None, **kwargs):
def fit(
self,
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(
model_factory=LGBMClassifier,
Expand All @@ -455,7 +509,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs):
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict(
model=self.to_local(),
Expand All @@ -466,7 +520,7 @@ def predict(self, X, **kwargs):

predict.__doc__ = LGBMClassifier.predict.__doc__

def predict_proba(self, X, **kwargs):
def predict_proba(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(
model=self.to_local(),
Expand All @@ -477,7 +531,7 @@ def predict_proba(self, X, **kwargs):

predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__

def to_local(self):
def to_local(self) -> LGBMClassifier:
"""Create regular version of lightgbm.LGBMClassifier from the distributed version.
Returns
Expand All @@ -491,7 +545,14 @@ def to_local(self):
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRegressor."""

def fit(self, X, y, sample_weight=None, client=None, **kwargs):
def fit(
self,
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(
model_factory=LGBMRegressor,
Expand All @@ -509,7 +570,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs):
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
def predict(self, X: _DaskMatrixLike, **kwargs) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict(
model=self.to_local(),
Expand All @@ -519,7 +580,7 @@ def predict(self, X, **kwargs):

predict.__doc__ = LGBMRegressor.predict.__doc__

def to_local(self):
def to_local(self) -> LGBMRegressor:
"""Create regular version of lightgbm.LGBMRegressor from the distributed version.
Returns
Expand All @@ -533,7 +594,16 @@ def to_local(self):
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRanker."""

def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
def fit(
self,
X: _DaskMatrixLike,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
Expand All @@ -555,13 +625,13 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)

def predict(self, X, **kwargs):
def predict(self, X: _DaskMatrixLike, **kwargs: Any) -> dask_Array:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(self.to_local(), X, **kwargs)

predict.__doc__ = LGBMRanker.predict.__doc__

def to_local(self):
def to_local(self) -> LGBMRanker:
"""Create regular version of lightgbm.LGBMRanker from the distributed version.
Returns
Expand Down

0 comments on commit ea8e47e

Please sign in to comment.