[dask] Add type hints in Dask package #3866

Merged Jan 29, 2021 · 25 commits · Changes from 8 commits
116 changes: 93 additions & 23 deletions python-package/lightgbm/dask.py
@@ -9,7 +9,7 @@
import socket
from collections import defaultdict
from copy import deepcopy
from typing import Dict, Iterable
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from urllib.parse import urlparse

import numpy as np
@@ -19,7 +19,12 @@
from .compat import (PANDAS_INSTALLED, pd_DataFrame, pd_Series, concat,
SKLEARN_INSTALLED,
DASK_INSTALLED, dask_Frame, dask_Array, delayed, Client, default_client, get_worker, wait)
from .sklearn import LGBMClassifier, LGBMRegressor, LGBMRanker
from .sklearn import LGBMClassifier, LGBMModel, LGBMRegressor, LGBMRanker

_1DArrayLike = Union[List, np.ndarray]
_DaskCollection = Union[dask_Array, dask_Frame]
_DaskPart = Union[np.ndarray, pd_DataFrame, pd_Series, ss.spmatrix]
_PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]
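For readers new to typing.Type: these Union[Type[...], ...] aliases accept the dtype classes themselves (np.float32 and friends), not instances of them. A minimal sketch of how such an alias is consumed (the _cast helper is hypothetical, not part of this diff):

    from typing import Type, Union

    import numpy as np

    _PredictionDtype = Union[Type[np.float32], Type[np.float64], Type[np.int32], Type[np.int64]]

    def _cast(block: np.ndarray, dtype: _PredictionDtype) -> np.ndarray:
        # dtype is a class object such as np.float32, not an instance
        return block.astype(dtype)

    _cast(np.array([0.25, 0.75]), np.float32)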


def _find_open_port(worker_ip: str, local_listen_port: int, ports_to_skip: Iterable[int]) -> int:
@@ -102,7 +107,7 @@ def _find_ports_for_workers(client: Client, worker_addresses: Iterable[str], loc
return worker_ip_to_port
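The body of _find_open_port is folded above; as a rough sketch of what checking one candidate port involves (an assumption about the approach, not the PR's exact implementation):

    import socket

    def _port_is_free(ip: str, port: int) -> bool:
        # try to bind the port; failure means something already listens there
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            try:
                sock.bind((ip, port))
                return True
            except OSError:
                return False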


def _concat(seq):
def _concat(seq: Iterable[_DaskPart]) -> _DaskPart:
Contributor:

I get that technically any Iterable (tuple/list/np.array/...) could work for seq here, but what is the benefit of using Iterable rather than List? I guess you would never use the List type hint for an input parameter (i.e., always prefer Iterable over List) unless the function called some list-specific method like .sort or relied on mutability?

I only raise this because I came across this S/O comment: https://stackoverflow.com/a/52827511/14480058, and I'm not sure what to think.

Collaborator Author:

I didn't have any special reason for choosing Iterable instead of List. I was just looking at the code and didn't see any list-specific stuff, so I thought it made sense to use the more general thing.

But since this is a totally-internal function, where we control the input, and since Iterable does weird stuff as that S/O post points out, I'll change this to List.

Collaborator Author:

fixed in bfd9dc0
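To make the concern above concrete: the body of _concat, resumed just below, indexes seq[0], and a generator satisfies Iterable but is not subscriptable, so List is the safer hint. A minimal sketch (the first_part helper is hypothetical):

    import numpy as np

    def first_part(seq):
        # mirrors _concat's seq[0] access
        return seq[0]

    parts = [np.zeros(2), np.ones(2)]
    first_part(parts)  # fine: a list supports indexing
    try:
        first_part(p for p in parts)
    except TypeError as err:
        print(err)  # 'generator' object is not subscriptable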

if isinstance(seq[0], np.ndarray):
return np.concatenate(seq, axis=0)
elif isinstance(seq[0], (pd_DataFrame, pd_Series)):
@@ -113,8 +118,15 @@ def _concat(seq):
raise TypeError('Data must be one of: numpy arrays, pandas dataframes, sparse matrices (from scipy). Got %s.' % str(type(seq[0])))


def _train_part(params, model_factory, list_of_parts, worker_address_to_port, return_model,
time_out=120, **kwargs):
def _train_part(
params: Dict[str, Any],
model_factory: Type[LGBMModel],
list_of_parts: List[Dict[str, _DaskPart]],
worker_address_to_port: Dict[str, int],
return_model: bool,
time_out: int = 120,
**kwargs: Any
) -> Optional[LGBMModel]:
local_worker_address = get_worker().address
machine_list = ','.join([
'%s:%d' % (urlparse(worker_address).hostname, port)
@@ -158,7 +170,7 @@ def _train_part(params, model_factory, list_of_parts, worker_address_to_port, re
return model if return_model else None
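For context on machine_list above: each worker contributes hostname:port, joined with commas into LightGBM's machine-list string. A small sketch of the same expression with made-up addresses:

    from urllib.parse import urlparse

    worker_address_to_port = {
        "tcp://127.0.0.1:34567": 12400,
        "tcp://127.0.0.1:45678": 12401,
    }
    machine_list = ','.join(
        '%s:%d' % (urlparse(address).hostname, port)
        for address, port in worker_address_to_port.items()
    )
    # machine_list == "127.0.0.1:12400,127.0.0.1:12401"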


def _split_to_parts(data, is_matrix):
def _split_to_parts(data: _DaskCollection, is_matrix: bool) -> List[_DaskPart]:
parts = data.to_delayed()
if isinstance(parts, np.ndarray):
if is_matrix:
@@ -169,7 +181,16 @@ def _split_to_parts(data, is_matrix):
return parts
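The List[_DaskPart] return type follows from how to_delayed behaves: a 2-D Dask array yields a 2-D grid of Delayed blocks (one per chunk), which is why the matrix case gets flattened. A minimal sketch (assumes dask is installed; shapes illustrative):

    import dask.array as da

    x = da.random.random((4, 4), chunks=(2, 4))
    parts = x.to_delayed()           # numpy array of Delayed blocks, shape (2, 1)
    flat = parts.flatten().tolist()  # the flat list of parts that gets returned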


def _train(client, data, label, params, model_factory, sample_weight=None, group=None, **kwargs):
def _train(
client: Client,
data: _DaskCollection,
label: _DaskCollection,
params: Dict[str, Any],
model_factory: Type[LGBMModel],
sample_weight: Optional[_DaskCollection] = None,
group: Optional[_1DArrayLike] = None,
Contributor:

group: Optional[_DaskCollection] = None,

It would be great if the docs also reflected that sample_weight and group are, up to this point, still distributed vectors.

Collaborator Author:

I'll update the docs in this PR too; I think that's directly related to this scope. Thanks!

Collaborator Author:

Thinking through this, I realized that with the way we do doc inheritance, it will actually be a bit tricky to override the docs for group.

I wrote up my thoughts in #3871, but I won't update the docs (except those on internal functions) in this PR


**kwargs: Any
) -> LGBMModel:
"""Inner train routine.

Parameters
@@ -303,7 +324,15 @@ def _train(client, data, label, params, model_factory, sample_weight=None, group
return results[0]
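Following the thread above, the distinction at stake is a locally-held group versus one that is still a distributed collection. A minimal sketch of the two shapes (values illustrative; assumes dask is installed):

    import numpy as np
    import dask.array as da

    # local, 1-D group sizes: two query groups covering 2 and 3 rows
    group_local = np.array([2, 3])

    # the distributed form the thread converges on: group stays a
    # Dask collection, partitioned alongside the rows of X
    group_distributed = da.from_array(group_local, chunks=1)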


def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, **kwargs):
def _predict_part(
part: _DaskPart,
model: LGBMModel,
raw_score: bool,
pred_proba: bool,
pred_leaf: bool,
pred_contrib: bool,
**kwargs: Any
) -> _DaskPart:
data = part.values if isinstance(part, pd_DataFrame) else part

if data.shape[0] == 0:
@@ -334,8 +363,16 @@ def _predict_part(part, model, raw_score, pred_proba, pred_leaf, pred_contrib, *
return result
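_predict_part is written to run independently on each partition or block of the input, which is how the _predict wrapper below applies it. A minimal sketch of that per-partition pattern (the _square stand-in is hypothetical, not the PR's code):

    import pandas as pd
    import dask.dataframe as dd

    def _square(part: pd.DataFrame) -> pd.DataFrame:
        # stand-in for _predict_part: sees one partition at a time
        return part ** 2

    ddf = dd.from_pandas(pd.DataFrame({"x": range(6)}), npartitions=3)
    print(ddf.map_partitions(_square).compute())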


def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pred_contrib=False,
dtype=np.float32, **kwargs):
def _predict(
model: LGBMModel,
data: _DaskCollection,
Contributor (@ffineis, Jan 27, 2021):

The corresponding docstring for the data parameter reads "dask array of shape = [n_samples, n_features]". Consider changing it to "data : dask array or dask DataFrame of shape = [n_samples, n_features]"?

Collaborator Author:

Yup, will do in this PR, thanks!

Collaborator Author:

fixed in bfd9dc0

raw_score: bool = False,
pred_proba: bool = False,
pred_leaf: bool = False,
pred_contrib: bool = False,
dtype: _PredictionDtype = np.float32,
**kwargs: Any
) -> _DaskCollection:
"""Inner predict routine.

Parameters
@@ -398,7 +435,17 @@ def _predict(model, data, raw_score=False, pred_proba=False, pred_leaf=False, pr


class _DaskLGBMModel:
def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None, **kwargs):

def _fit(
self,
model_factory: Type[LGBMModel],
X: _DaskCollection,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
group: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "_DaskLGBMModel":
if not all((DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED)):
raise LightGBMError('dask, pandas and scikit-learn are required for lightgbm.dask')
if client is None:
@@ -422,13 +469,13 @@ def _fit(self, model_factory, X, y, sample_weight=None, group=None, client=None,

return self
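The client: Optional[Client] hint reflects the fallback above: when no client is passed, the session's default client is used. A minimal sketch of that pattern (the _resolve_client helper is hypothetical; assumes dask.distributed):

    from typing import Optional

    from dask.distributed import Client, default_client

    def _resolve_client(client: Optional[Client] = None) -> Client:
        # mirror _fit's fallback; default_client() raises if no cluster is running
        return client if client is not None else default_client()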

def _to_local(self, model_factory):
def _to_local(self, model_factory: Type[LGBMModel]) -> LGBMModel:
model = model_factory(**self.get_params())
self._copy_extra_params(self, model)
return model

@staticmethod
def _copy_extra_params(source, dest):
def _copy_extra_params(source: "_DaskLGBMModel", dest: "_DaskLGBMModel") -> None:
Contributor:

I think that dest could refer to either an LGBMModel or a _DaskLGBMModel, referring to line 473 (in _to_local).

def _copy_extra_params(source: "_DaskLGBMModel", dest: Union["_DaskLGBMModel", LGBMModel]) -> None:

Collaborator Author:

Shoot, you're totally right. Good eye.

Collaborator Author:

fixed in bfd9dc0
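On the quoted "_DaskLGBMModel" annotations: the class name is written as a string because it is a forward reference; the name is not bound yet while the class body is being evaluated. A minimal sketch with a hypothetical class:

    class Node:
        def clone(self) -> "Node":
            # quoted because "Node" is not yet defined while the class
            # body is being evaluated
            return Node()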

params = source.get_params()
attributes = source.__dict__
extra_param_names = set(attributes.keys()).difference(params.keys())
@@ -439,7 +486,14 @@ def _copy_extra_params(source, dest):
class DaskLGBMClassifier(LGBMClassifier, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMClassifier."""

def fit(self, X, y, sample_weight=None, client=None, **kwargs):
def fit(
self,
X: _DaskCollection,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
"""Docstring is inherited from the lightgbm.LGBMClassifier.fit."""
return self._fit(
model_factory=LGBMClassifier,
@@ -457,7 +511,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs):
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
def predict(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict."""
return _predict(
model=self.to_local(),
@@ -468,7 +522,7 @@ def predict(self, X, **kwargs):

predict.__doc__ = LGBMClassifier.predict.__doc__

def predict_proba(self, X, **kwargs):
def predict_proba(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
"""Docstring is inherited from the lightgbm.LGBMClassifier.predict_proba."""
return _predict(
model=self.to_local(),
@@ -479,7 +533,7 @@ def predict_proba(self, X, **kwargs):

predict_proba.__doc__ = LGBMClassifier.predict_proba.__doc__

def to_local(self):
def to_local(self) -> LGBMClassifier:
"""Create regular version of lightgbm.LGBMClassifier from the distributed version.

Returns
@@ -493,7 +547,14 @@ def to_local(self):
class DaskLGBMRegressor(LGBMRegressor, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRegressor."""

def fit(self, X, y, sample_weight=None, client=None, **kwargs):
def fit(
self,
X: _DaskCollection,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
"""Docstring is inherited from the lightgbm.LGBMRegressor.fit."""
return self._fit(
model_factory=LGBMRegressor,
@@ -511,7 +572,7 @@ def fit(self, X, y, sample_weight=None, client=None, **kwargs):
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _init_score + _after_init_score)

def predict(self, X, **kwargs):
def predict(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
"""Docstring is inherited from the lightgbm.LGBMRegressor.predict."""
return _predict(
model=self.to_local(),
@@ -521,7 +582,7 @@ def predict(self, X, **kwargs):

predict.__doc__ = LGBMRegressor.predict.__doc__

def to_local(self):
def to_local(self) -> LGBMRegressor:
"""Create regular version of lightgbm.LGBMRegressor from the distributed version.

Returns
@@ -535,7 +596,16 @@ def to_local(self):
class DaskLGBMRanker(LGBMRanker, _DaskLGBMModel):
"""Distributed version of lightgbm.LGBMRanker."""

def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None, **kwargs):
def fit(
self,
X: _DaskCollection,
y: _DaskCollection,
sample_weight: Optional[_DaskCollection] = None,
init_score: Optional[_DaskCollection] = None,
group: Optional[_1DArrayLike] = None,
Contributor (@ffineis, Jan 27, 2021):

I think this might be group: Optional[_DaskCollection]; the group array is still distributed at this point. Since group can be so much smaller than both X and y, I think supporting a locally-defined group input (list or array) is a noble cause, but that would need its own PR defining which parts of group get sent where to accompany the X and y parts.

If this comment stands, then _1DArrayLike can be removed at the top of this file.

Collaborator Author:

Ooooo, OK! I misunderstood what was happening. Awesome, I'll look again and then change that hint.

I'll also write up a feature request for this. I think it's non-breaking and additive, so it could be done after 3.2.0. But like you said, since group is often fairly small, it would be nice to be able to specify it as a list or a little numpy array.

client: Optional[Client] = None,
**kwargs: Any
) -> "DaskLGBMRanker":
"""Docstring is inherited from the lightgbm.LGBMRanker.fit."""
if init_score is not None:
raise RuntimeError('init_score is not currently supported in lightgbm.dask')
@@ -557,13 +627,13 @@ def fit(self, X, y, sample_weight=None, init_score=None, group=None, client=None
+ ' ' * 12 + 'Dask client.\n'
+ ' ' * 8 + _eval_set + _after_eval_set)
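Putting the resolved hints together, a rough usage sketch for the ranker (assumes a local dask.distributed cluster; data and group sizes are made up, with group boundaries aligned to X's partitions by construction):

    import numpy as np
    import dask.array as da
    from dask.distributed import Client
    from lightgbm.dask import DaskLGBMRanker

    client = Client()  # spins up a local cluster
    X = da.random.random((100, 10), chunks=(25, 10))
    y = da.random.random(100, chunks=25)
    # one 25-row query group per partition of X
    group = da.from_array(np.array([25, 25, 25, 25]), chunks=1)

    ranker = DaskLGBMRanker().fit(X, y, group=group, client=client)
    preds = ranker.predict(X)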

def predict(self, X, **kwargs):
def predict(self, X: _DaskCollection, **kwargs: Any) -> _DaskCollection:
"""Docstring is inherited from the lightgbm.LGBMRanker.predict."""
return _predict(self.to_local(), X, **kwargs)

predict.__doc__ = LGBMRanker.predict.__doc__

def to_local(self):
def to_local(self) -> LGBMRanker:
"""Create regular version of lightgbm.LGBMRanker from the distributed version.

Returns