[dask] add support for custom objective functions (fixes #3934) (#4920)
* add test for custom objective with regressor

* add test for custom binary classification objective with classifier

* isort

* got tests working for multiclass

* update docs

* train deeper model for classifier

* Apply suggestions from code review

Co-authored-by: José Morales <[email protected]>

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* update multiclass tests

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* fix multiclass probabilities

* linting

Co-authored-by: José Morales <[email protected]>
Co-authored-by: Nikita Titov <[email protected]>
3 people authored Jan 17, 2022
1 parent 4aaeb22 commit a06fadf
Showing 3 changed files with 259 additions and 18 deletions.
35 changes: 35 additions & 0 deletions docs/Parallel-Learning-Guide.rst
@@ -230,6 +230,41 @@ You could edit your firewall rules to allow communication between any of the wor
* the port ``local_listen_port`` is not open on any of the worker hosts
* any machine has multiple Dask worker processes running on it

Using Custom Objective Functions with Dask
******************************************

It is possible to customize the boosting process by providing a custom objective function written in Python.
See the Dask API's documentation for details on how to implement such functions.

.. warning::

    Custom objective functions used with ``lightgbm.dask`` will be called by each worker process on only that worker's local data.

Follow the example below to use a custom implementation of the ``regression_l2`` objective.

.. code:: python

    import dask.array as da
    import lightgbm as lgb
    import numpy as np
    from distributed import Client, LocalCluster

    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    X = da.random.random((1000, 10), (500, 10))
    y = da.random.random((1000,), (500,))

    def custom_l2_obj(y_true, y_pred):
        grad = y_pred - y_true
        hess = np.ones(len(y_true))
        return grad, hess

    dask_model = lgb.DaskLGBMRegressor(
        objective=custom_l2_obj
    )
    dask_model.fit(X, y)
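
A classification objective follows the same pattern. The sketch below is not part of this commit; it adapts the logistic objective from the tests added here and reuses the ``client`` created above (the names ``X_clf``, ``y_clf`` and ``custom_binary_obj`` are illustrative). Per the new tests, the prediction result is a raw score rather than a predicted class when a custom objective is used, so the sigmoid is applied manually.

.. code:: python

    X_clf = da.random.random((1000, 10), (500, 10))
    y_clf = da.random.randint(0, 2, size=(1000,), chunks=(500,))

    def custom_binary_obj(y_true, y_pred):
        # standard logistic-loss gradient and hessian computed from raw scores
        prob = 1.0 / (1.0 + np.exp(-y_pred))
        grad = prob - y_true
        hess = prob * (1.0 - prob)
        return grad, hess

    dask_clf = lgb.DaskLGBMClassifier(objective=custom_binary_obj)
    dask_clf.fit(X_clf, y_clf)

    # with a custom objective, predict() returns raw scores instead of class labels
    raw_scores = dask_clf.predict(X_clf).compute()
    probabilities = 1.0 / (1.0 + np.exp(-raw_scores))
    predicted_classes = (probabilities > 0.5).astype(np.int64)
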
Prediction with Dask
''''''''''''''''''''

26 changes: 8 additions & 18 deletions python-package/lightgbm/dask.py
Expand Up @@ -22,7 +22,8 @@
dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
default_client, delayed, pd_DataFrame, pd_Series, wait)
from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
_lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
_LGBM_ScikitCustomObjectiveFunction, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit,
_lgbmmodel_doc_predict)

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
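
For readers unfamiliar with the newly imported ``_LGBM_ScikitCustomObjectiveFunction``: judging from the objectives defined in this commit's docs and tests, it covers plain Python callables of roughly the shape sketched below. This is an illustration of the expected call signature, not the exact alias definition from ``sklearn.py``; the names ``ObjectiveFunction`` and ``least_squares_objective`` are hypothetical.

from typing import Callable, Tuple

import numpy as np

# assumed shape of a custom objective: take the true labels and the current raw
# predictions for the local data partition, return per-sample gradients and hessians
ObjectiveFunction = Callable[[np.ndarray, np.ndarray], Tuple[np.ndarray, np.ndarray]]

def least_squares_objective(y_true: np.ndarray, y_pred: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    grad = y_pred - y_true
    hess = np.ones_like(y_true, dtype=np.float64)
    return grad, hess

custom_objective: ObjectiveFunction = least_squares_objective
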
@@ -1099,7 +1100,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[str] = None,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
@@ -1142,16 +1143,12 @@

_base_doc = LGBMClassifier.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore
_base_doc = f"""
__init__.__doc__ = f"""
{_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
{' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
{_kwargs}{_after_kwargs}
"""

# the note on custom objective functions in LGBMModel.__init__ is not
# currently relevant for the Dask estimators
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_dask_getstate()

@@ -1275,7 +1272,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[str] = None,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
@@ -1318,14 +1315,11 @@

_base_doc = LGBMRegressor.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore
_base_doc = f"""
__init__.__doc__ = f"""
{_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
{' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
{_kwargs}{_after_kwargs}
"""
# the note on custom objective functions in LGBMModel.__init__ is not
# currently relevant for the Dask estimators
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_dask_getstate()
@@ -1431,7 +1425,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
objective: Optional[str] = None,
objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
@@ -1474,16 +1468,12 @@

_base_doc = LGBMRanker.__init__.__doc__
_before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore
_base_doc = f"""
__init__.__doc__ = f"""
{_before_kwargs}client : dask.distributed.Client or None, optional (default=None)
{' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled.
{_kwargs}{_after_kwargs}
"""

# the note on custom objective functions in LGBMModel.__init__ is not
# currently relevant for the Dask estimators
__init__.__doc__ = _base_doc[:_base_doc.find('Note\n')]

def __getstate__(self) -> Dict[Any, Any]:
return self._lgb_dask_getstate()

216 changes: 216 additions & 0 deletions tests/python_package_test/test_dask.py
@@ -262,6 +262,38 @@ def _unpickle(filepath, serializer):
raise ValueError(f'Unrecognized serializer type: {serializer}')


def _objective_least_squares(y_true, y_pred):
    grad = y_pred - y_true
    hess = np.ones(len(y_true))
    return grad, hess


def _objective_logistic_regression(y_true, y_pred):
    y_pred = 1.0 / (1.0 + np.exp(-y_pred))
    grad = y_pred - y_true
    hess = y_pred * (1.0 - y_pred)
    return grad, hess


def _objective_logloss(y_true, y_pred):
    num_rows = len(y_true)
    num_class = len(np.unique(y_true))
    # operate on preds as [num_data, num_classes] matrix
    y_pred = y_pred.reshape(-1, num_class, order='F')
    row_wise_max = np.max(y_pred, axis=1).reshape(num_rows, 1)
    preds = y_pred - row_wise_max
    prob = np.exp(preds) / np.sum(np.exp(preds), axis=1).reshape(num_rows, 1)
    grad_update = np.zeros_like(preds)
    grad_update[np.arange(num_rows), y_true.astype(np.int32)] = -1.0
    grad = prob + grad_update
    factor = num_class / (num_class - 1)
    hess = factor * prob * (1 - prob)
    # reshape back to 1-D array, grouped by class id and then row id
    grad = grad.T.reshape(-1)
    hess = hess.T.reshape(-1)
    return grad, hess
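
As a side note (not part of the diff): for multiclass tasks the raw predictions reach ``_objective_logloss`` as a single 1-D array grouped by class and then by row, which is why the helper reshapes with ``order='F'`` and flattens the gradients with ``grad.T.reshape(-1)``. A minimal standalone check of that layout, assuming the helper above and ``numpy`` imported as ``np``:

# hypothetical sanity check of the class-major layout, three classes and four rows
raw = np.arange(12, dtype=np.float64).reshape(4, 3)  # column j holds raw scores for class j
labels = np.array([0, 1, 2, 1])
flat_preds = raw.reshape(-1, order='F')  # all class-0 scores, then class-1, then class-2
grad, hess = _objective_logloss(labels, flat_preds)
assert grad.shape == (12,) and hess.shape == (12,)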


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
@pytest.mark.parametrize('boosting_type', boosting_types)
@@ -455,6 +487,79 @@ def test_classifier_pred_contrib(output, task, cluster):
assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1)


@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification'])
def test_classifier_custom_objective(output, task, cluster):
    with Client(cluster) as client:
        X, y, w, _, dX, dy, dw, _ = _create_data(
            objective=task,
            output=output,
        )

        params = {
            "n_estimators": 50,
            "num_leaves": 31,
            "verbose": -1,
            "seed": 708,
            "deterministic": True,
            "force_col_wise": True
        }

        if task == 'binary-classification':
            params.update({
                'objective': _objective_logistic_regression,
            })
        elif task == 'multiclass-classification':
            params.update({
                'objective': _objective_logloss,
                'num_classes': 3
            })

        dask_classifier = lgb.DaskLGBMClassifier(
            client=client,
            time_out=5,
            tree_learner='data',
            **params
        )
        dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw)
        dask_classifier_local = dask_classifier.to_local()
        p1_raw = dask_classifier.predict(dX, raw_score=True).compute()
        p1_raw_local = dask_classifier_local.predict(X, raw_score=True)

        local_classifier = lgb.LGBMClassifier(**params)
        local_classifier.fit(X, y, sample_weight=w)
        p2_raw = local_classifier.predict(X, raw_score=True)

        # with a custom objective, prediction result is a raw score instead of predicted class
        if task == 'binary-classification':
            p1_proba = 1.0 / (1.0 + np.exp(-p1_raw))
            p1_class = (p1_proba > 0.5).astype(np.int64)
            p1_proba_local = 1.0 / (1.0 + np.exp(-p1_raw_local))
            p1_class_local = (p1_proba_local > 0.5).astype(np.int64)
            p2_proba = 1.0 / (1.0 + np.exp(-p2_raw))
            p2_class = (p2_proba > 0.5).astype(np.int64)
        elif task == 'multiclass-classification':
            p1_proba = np.exp(p1_raw) / np.sum(np.exp(p1_raw), axis=1).reshape(-1, 1)
            p1_class = p1_proba.argmax(axis=1)
            p1_proba_local = np.exp(p1_raw_local) / np.sum(np.exp(p1_raw_local), axis=1).reshape(-1, 1)
            p1_class_local = p1_proba_local.argmax(axis=1)
            p2_proba = np.exp(p2_raw) / np.sum(np.exp(p2_raw), axis=1).reshape(-1, 1)
            p2_class = p2_proba.argmax(axis=1)

        # function should have been preserved
        assert callable(dask_classifier.objective_)
        assert callable(dask_classifier_local.objective_)

        # should correctly classify every sample
        assert_eq(p1_class, y)
        assert_eq(p1_class_local, y)
        assert_eq(p2_class, y)

        # probability estimates should be similar
        assert_eq(p1_proba, p2_proba, atol=0.03)
        assert_eq(p1_proba, p1_proba_local)


def test_group_workers_by_host():
hosts = [f'0.0.0.{i}' for i in range(2)]
workers = [f'tcp://{host}:{p}' for p in range(2) for host in hosts]
@@ -700,6 +805,56 @@ def test_regressor_quantile(output, alpha, cluster):
assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='


@pytest.mark.parametrize('output', data_output)
def test_regressor_custom_objective(output, cluster):
    with Client(cluster) as client:
        X, y, w, _, dX, dy, dw, _ = _create_data(
            objective='regression',
            output=output
        )

        params = {
            "n_estimators": 10,
            "num_leaves": 10,
            "objective": _objective_least_squares
        }

        dask_regressor = lgb.DaskLGBMRegressor(
            client=client,
            time_out=5,
            tree_learner='data',
            **params
        )
        dask_regressor = dask_regressor.fit(dX, dy, sample_weight=dw)
        dask_regressor_local = dask_regressor.to_local()
        p1 = dask_regressor.predict(dX)
        p1_local = dask_regressor_local.predict(X)
        s1_local = dask_regressor_local.score(X, y)
        s1 = _r2_score(dy, p1)
        p1 = p1.compute()

        local_regressor = lgb.LGBMRegressor(**params)
        local_regressor.fit(X, y, sample_weight=w)
        p2 = local_regressor.predict(X)
        s2 = local_regressor.score(X, y)

        # function should have been preserved
        assert callable(dask_regressor.objective_)
        assert callable(dask_regressor_local.objective_)

        # Scores should be the same
        assert_eq(s1, s2, atol=0.01)
        assert_eq(s1, s1_local)

        # local and Dask predictions should be the same
        assert_eq(p1, p1_local)

        # predictions should be better than random
        assert_precision = {"rtol": 0.5, "atol": 50.}
        assert_eq(p1, y, **assert_precision)
        assert_eq(p2, y, **assert_precision)


@pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
@pytest.mark.parametrize('group', [None, group_sizes])
@pytest.mark.parametrize('boosting_type', boosting_types)
@@ -808,6 +963,67 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster):
assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '=='


@pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical'])
def test_ranker_custom_objective(output, cluster):
    with Client(cluster) as client:
        if output == 'dataframe-with-categorical':
            X, y, w, g, dX, dy, dw, dg = _create_data(
                objective='ranking',
                output=output,
                group=group_sizes,
                n_features=1,
                n_informative=1
            )
        else:
            X, y, w, g, dX, dy, dw, dg = _create_data(
                objective='ranking',
                output=output,
                group=group_sizes
            )

        # rebalance small dask.Array dataset for better performance.
        if output == 'array':
            dX = dX.persist()
            dy = dy.persist()
            dw = dw.persist()
            dg = dg.persist()
            _ = wait([dX, dy, dw, dg])
            client.rebalance()

        params = {
            "random_state": 42,
            "n_estimators": 50,
            "num_leaves": 20,
            "min_child_samples": 1,
            "objective": _objective_least_squares
        }

        dask_ranker = lgb.DaskLGBMRanker(
            client=client,
            time_out=5,
            tree_learner_type="data",
            **params
        )
        dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg)
        rnkvec_dask = dask_ranker.predict(dX).compute()
        dask_ranker_local = dask_ranker.to_local()
        rnkvec_dask_local = dask_ranker_local.predict(X)

        local_ranker = lgb.LGBMRanker(**params)
        local_ranker.fit(X, y, sample_weight=w, group=g)
        rnkvec_local = local_ranker.predict(X)

        # distributed ranker should be able to rank decently well with the least-squares objective
        # and should have high rank correlation with scores from serial ranker.
        assert spearmanr(rnkvec_dask, y).correlation > 0.6
        assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.8
        assert_eq(rnkvec_dask, rnkvec_dask_local)

        # function should have been preserved
        assert callable(dask_ranker.objective_)
        assert callable(dask_ranker_local.objective_)


@pytest.mark.parametrize('task', tasks)
@pytest.mark.parametrize('output', data_output)
@pytest.mark.parametrize('eval_sizes', [[0.5, 1, 1.5], [0]])
