[dask] Add type hints in Dask package #3866
Changes from 8 commits
@@ -9,7 +9,7 @@
 import socket
 from collections import defaultdict
 from copy import deepcopy
-from typing import Dict, Iterable
+from typing import Any, Dict, Iterable, List, Optional, Type, Union
 from urllib.parse import urlparse

 import numpy as np
@@ -19,7 +19,12 @@
 from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
                      SKLEARN_INSTALLED,
                      DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait)
-from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
+from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker

+_1DArrayLike = Union[List, np.ndarray]
+_DaskCollection = Union[dask_Array, dask_Frame]
+_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
+_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
+

 def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
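The four module-level aliases added above are static-typing shorthand for unions that recur throughout the file. As a rough, hypothetical illustration of how such an alias reads at a call site (using plain numpy/pandas/scipy imports rather than the package's compat shims):

    from typing import Union

    import numpy as np
    import pandas as pd
    import scipy.sparse as ss

    # mirrors the _DaskPart alias introduced in this diff
    _DaskPart = Union[np.ndarray, pd.DataFrame, pd.Series, ss.spmatrix]

    def _n_rows(part: _DaskPart) -> int:
        # every member of the union exposes .shape[0]
        return part.shape[0]

    print(_n_rows(np.zeros((3, 2))))             # 3
    print(_n_rows(pd.DataFrame({"a": [1, 2]})))  # 2
    print(_n_rows(ss.csr_matrix((4, 5))))        # 4

The aliases have no runtime effect; they only let a checker such as mypy verify that callers pass one of the allowed types.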
@@ -102,7 +107,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
     return worker_ip_to_port


-def _concat(seq):
+def _concat(seq: Iterable[_DaskPart]) -> _DaskPart:
     if isinstance(seq[0], np.ndarray):
         return np.concatenate(seq, axis=0)
     elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
@@ -113,8 +118,15 @@ def _concat(seq):
     raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))


-def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
-                time_out=120, **kwargs):
+def _train_part(
+    params: Dict[str, Any],
+    model_factory: Type[LGBMModel],
+    list_of_parts: List[Dict[str, _DaskPart]],
+    worker_address_to_port: Dict[str, int],
+    return_model: bool,
+    time_out: int = 120,
+    **kwargs: Any
+) -> Optional[LGBMModel]:
     local_worker_address = get_worker().address
     machine_list = ','.join([
         '%s:%d' % (urlparse(worker_address).hostname, port)
@@ -158,7 +170,7 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
     return model if return_model else None


-def _split_to_parts(data, is_matrix):
+def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
     parts = data.to_delayed()
     if isinstance(parts, np.ndarray):
         if is_matrix:
@@ -169,7 +181,16 @@ def _split_to_parts(data, is_matrix):
     return parts


-def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
+def _train(
+    client: Client,
+    data: _DaskCollection,
+    label: _DaskCollection,
+    params: Dict[str, Any],
+    model_factory: Type[LGBMModel],
+    sample_weight: Optional[_DaskCollection] = None,
+    group: Optional[_1DArrayLike] = None,
Review comment: Would be great if the docs also reflected that.

Reply: I'll update the docs in this PR too, I think that's directly related to this scope. Thanks!

Reply: Thinking through this, I realized that with the way we do doc inheritance, it will actually be a bit tricky to override the docs for [...]. I wrote up my thoughts in #3871, but I won't update the docs (except those on internal functions) in this PR.

Reply: fixed the hint in https://docs.dask.org/en/latest/dataframe-api.html#dask.dataframe.to_csv

(A sketch of the doc-inheritance pattern this thread refers to follows this hunk.)
+    **kwargs: Any
+) -> LGBMModel:
     """Inner train routine.

     Parameters
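For context on the doc-inheritance discussion above: the Dask estimators reuse the sklearn docstrings and splice extra text in with plain string concatenation (visible later in this diff as the ' ' * 12 + 'Dask client.\n' lines). A rough, hypothetical sketch of that pattern, with made-up class names, shows why appending a new parameter is easy but rewording an inherited one is awkward:

    class Base:
        def fit(self, X, y):
            """Fit the model.

            Parameters
            ----------
            X : array-like
                Training data.
            y : array-like
                Target values.
            """

    class DaskWrapper(Base):
        def fit(self, X, y, client=None):
            return super().fit(X, y)

        # Inherit the parent docstring, then append the Dask-only parameter by
        # string concatenation. Rewording an existing parameter (e.g. documenting
        # Dask collection types for X or y) would mean slicing into the middle of
        # the inherited string, which is what #3871 discusses.
        fit.__doc__ = (Base.fit.__doc__
                       + ' ' * 8 + 'client : dask.distributed.Client or None\n'
                       + ' ' * 12 + 'Dask client.\n')

    print(DaskWrapper.fit.__doc__)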
@@ -303,7 +324,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
     return results[0]


-def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
+def _predict_part(
+    part: _DaskPart,
+    model: LGBMModel,
+    raw_score: bool,
+    pred_proba: bool,
+    pred_leaf: bool,
+    pred_contrib: bool,
+    **kwargs: Any
+) -> _DaskPart:
     data = part.values if isinstance(part, pd_DataFrame) else part

     if data.shape[0] == 0:
@@ -334,8 +363,16 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, *
     return result


-def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
-             dtype=np.float32, **kwargs):
+def _predict(
+    model: LGBMModel,
+    data: _DaskCollection,
Review comment: The corresponding docstr for the [...]

Reply: yup will do in this PR, thanks!

Reply: fixed in bfd9dc0
+    raw_score: bool = False,
+    pred_proba: bool = False,
+    pred_leaf: bool = False,
+    pred_contrib: bool = False,
+    dtype: _PredictionDtype = np.float32,
+    **kwargs: Any
+) -> _DaskCollection:
     """Inner predict routine.

     Parameters
@@ -398,7 +435,17 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr


 class _DaskLGBMModel:
-    def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs):
+
+    def _fit(
+        self,
+        model_factory: Type[LGBMModel],
+        X: _DaskCollection,
+        y: _DaskCollection,
+        sample_weight: Optional[_DaskCollection] = None,
+        group: Optional[_DaskCollection] = None,
+        client: Optional[Client] = None,
+        **kwargs: Any
+    ) -> "_DaskLGBMModel":
         if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
             raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
         if client is None:
@@ -422,13 +469,13 @@ def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None,

         return self

-    def _to_local(self, model_factory):
+    def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
         model = model_factory(**self.get_params())
         self._copy_extra_params(self, model)
         return model

     @staticmethod
-    def _copy_extra_params(source, dest):
+    def _copy_extra_params(source: "_DaskLGBMModel", dest: "_DaskLGBMModel") -> None:
Review comment: I think that [...]

Reply: shoot, you're totally right. good eye

Reply: fixed in bfd9dc0
         params = source.get_params()
         attributes = source.__dict__
         extra_param_names = set(attributes.keys()).difference(params.keys())
@@ -439,7 +486,14 @@ def _copy_extra_params(source, dest):
 class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
     """Distributed version of lightgbm.LGBMClassifier."""

-    def fit(self, X, y, sample_weight=None, client=None, **kwargs):
+    def fit(
+        self,
+        X: _DaskCollection,
+        y: _DaskCollection,
+        sample_weight: Optional[_DaskCollection] = None,
+        client: Optional[Client] = None,
+        **kwargs: Any
+    ) -> "DaskLGBMClassifier":
         """Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
         return self._fit(
             model_factory=LGBMClassifier,
@@ -457,7 +511,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs):
         + ' ' * 12 + 'Dask client.\n'
         + ' ' * 8 + _init_score + _after_init_score)

-    def predict(self, X, **kwargs):
+    def predict(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
         """Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
         return _predict(
             model=self.to_local(),
@@ -468,7 +522,7 @@ def predict(self, X, **kwargs):

     predict.__doc__ = LGBMClassifier.predict.__doc__

-    def predict_proba(self, X, **kwargs):
+    def predict_proba(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
         """Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
         return _predict(
             model=self.to_local(),
@@ -479,7 +533,7 @@ def predict_proba(self, X, **kwargs):

     predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__

-    def to_local(self):
+    def to_local(self) -> LGBMClassifier:
         """Create regular version of lightgbm.LGBMClassifier from the distributed version.

         Returns
@@ -493,7 +547,14 @@ def to_local(self):
 class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
     """Distributed version of lightgbm.LGBMRegressor."""

-    def fit(self, X, y, sample_weight=None, client=None, **kwargs):
+    def fit(
+        self,
+        X: _DaskCollection,
+        y: _DaskCollection,
+        sample_weight: Optional[_DaskCollection] = None,
+        client: Optional[Client] = None,
+        **kwargs: Any
+    ) -> "DaskLGBMRegressor":
         """Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
         return self._fit(
             model_factory=LGBMRegressor,
@@ -511,7 +572,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs):
         + ' ' * 12 + 'Dask client.\n'
         + ' ' * 8 + _init_score + _after_init_score)

-    def predict(self, X, **kwargs):
+    def predict(self, X: _DaskCollection, **kwargs) -> _DaskCollection:
         """Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
         return _predict(
             model=self.to_local(),
@@ -521,7 +582,7 @@ def predict(self, X, **kwargs):

     predict.__doc__ = LGBMRegressor.predict.__doc__

-    def to_local(self):
+    def to_local(self) -> LGBMRegressor:
         """Create regular version of lightgbm.LGBMRegressor from the distributed version.

         Returns
@@ -535,7 +596,16 @@ def to_local(self):
 class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
     """Distributed version of lightgbm.LGBMRanker."""

-    def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
+    def fit(
+        self,
+        X: _DaskCollection,
+        y: _DaskCollection,
+        sample_weight: Optional[_DaskCollection] = None,
+        init_score: Optional[_DaskCollection] = None,
+        group: Optional[_1DArrayLike] = None,
Review comment: Think this might be [...]. If this comment stands, then [...]

Reply: oooooo ok! I misunderstood what was happening. Awesome, I'll look again and then change that hint. I'll also write up a feature request for this. I think it's something that's non-breaking and additive, that could be done after 3.2.0. But like you said, since [...]
+        client: Optional[Client] = None,
+        **kwargs: Any
+    ) -> "DaskLGBMRanker":
         """Docstring is inherited from the lightgbm.LGBMRanker.fit."""
         if init_score is not None:
             raise RuntimeError('init_score is not currently supported in lightgbm.dask')
@@ -557,13 +627,13 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None
         + ' ' * 12 + 'Dask client.\n'
         + ' ' * 8 + _eval_set + _after_eval_set)

-    def predict(self, X, **kwargs):
+    def predict(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
         """Docstring is inherited from the lightgbm.LGBMRanker.predict."""
         return _predict(self.to_local(), X, **kwargs)

     predict.__doc__ = LGBMRanker.predict.__doc__

-    def to_local(self):
+    def to_local(self) -> LGBMRanker:
         """Create regular version of lightgbm.LGBMRanker from the distributed version.

         Returns
Review comment (on the seq parameter of _concat): I get that technically any Iterable (tuple/list/np.array/...) could work for seq here, but what is the benefit to using Iterable and not List? I guess you would never use the List type hint for an input parameter (i.e. always prefer Iterable over List) unless the function called some list-specific method like .sort or relied on mutability...? I only raise this because I came across this S/O comment: https://stackoverflow.com/a/52827511/14480058, not sure what to think.

Reply: I didn't have any special reason for choosing Iterable instead of List. I was just looking at the code and didn't see any list-specific stuff, so I thought it made sense to use the more general thing. But since this is a totally-internal function, where we control the input, and since Iterable does weird stuff as that S/O post points out, I'll change this to List[].

Reply: fixed in bfd9dc0
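A minimal illustration of the distinction discussed above (hypothetical snippet, not from the PR): Iterable only promises that a value can be iterated, so the indexing that _concat performs (seq[0]) is outside its contract, while List and Sequence both guarantee it.

    from typing import Iterable, List, Sequence

    def first_of_iterable(seq: Iterable[int]) -> int:
        # a type checker flags this: Iterable does not guarantee indexing
        return seq[0]

    def first_of_list(seq: List[int]) -> int:
        return seq[0]  # fine: lists support indexing

    def first_of_sequence(seq: Sequence[int]) -> int:
        return seq[0]  # also fine, and still accepts tuples

    first_of_list([1, 2, 3])
    first_of_sequence((1, 2, 3))
    # A generator is a valid Iterable but would fail at runtime on seq[0]:
    # first_of_iterable(x for x in range(3))  # TypeError if uncommented

Since _concat indexes its argument and is only ever called with the lists that lightgbm.dask builds internally, List (or Sequence) states the actual requirement more precisely than Iterable.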