Support custom metric in sklearn ranker. (#8786)
trivialfis authored Feb 12, 2023
1 parent 17b709a commit 225b315
Showing 2 changed files with 76 additions and 7 deletions.
53 changes: 48 additions & 5 deletions python-package/xgboost/sklearn.py
@@ -4,6 +4,7 @@
import json
import os
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import (
Any,
Callable,
@@ -127,6 +128,49 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
return inner


def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
"""Decorate a learning to rank metric."""

def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
y_true = dmatrix.get_label()
group_ptr = dmatrix.get_uint_info("group_ptr")
if group_ptr.size < 2:
raise ValueError(
"Invalid `group_ptr`. Likely caused by invalid qid or group."
)
scores = np.empty(group_ptr.size - 1)
futures = []
weight = dmatrix.get_group()
no_weight = weight.size == 0

def task(i: int) -> float:
begin = group_ptr[i - 1]
end = group_ptr[i]
gy = y_true[begin:end]
gp = y_score[begin:end]
if gy.size == 1:
# Maybe there's a better default? 1.0 because many ranking score
# functions have output in range [0, 1].
return 1.0
return func(gy, gp)

workers = n_jobs if n_jobs is not None else os.cpu_count()
with ThreadPoolExecutor(max_workers=workers) as executor:
for i in range(1, group_ptr.size):
f = executor.submit(task, i)
futures.append(f)

for i, f in enumerate(futures):
scores[i] = f.result()

if no_weight:
return func.__name__, scores.mean()

return func.__name__, np.average(scores, weights=weight)

return inner


__estimator_doc = """
n_estimators : int
Number of gradient boosted trees. Equivalent to number of boosting
@@ -868,7 +912,10 @@ def _duplicated(parameter: str) -> None:
metric = eval_metric
elif callable(eval_metric):
# Parameter from constructor or set_params
metric = _metric_decorator(eval_metric)
if self._get_type() == "ranker":
metric = ltr_metric_decorator(eval_metric, self.n_jobs)
else:
metric = _metric_decorator(eval_metric)
else:
params.update({"eval_metric": eval_metric})

@@ -1979,10 +2026,6 @@ def fit(
) = self._configure_fit(
xgb_model, eval_metric, params, early_stopping_rounds, callbacks
)
if callable(metric):
raise ValueError(
"Custom evaluation metric is not yet supported for XGBRanker."
)

self._Booster = train(
params,
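The core of this change is ltr_metric_decorator: it adapts a callable metric to XGBoost's Metric signature by slicing labels and predictions per query group via group_ptr, evaluating the metric on each group in a thread pool sized by n_jobs, and averaging the per-group scores, weighted by the group weights when any are set. A minimal standalone sketch of that per-group averaging, using illustrative toy arrays rather than a real DMatrix:

    import numpy as np
    from sklearn.metrics import roc_auc_score

    # Toy stand-ins for what the decorator reads from the DMatrix.
    y_true = np.array([1, 0, 0, 1, 1, 0])
    y_score = np.array([0.9, 0.2, 0.3, 0.8, 0.4, 0.6])
    group_ptr = np.array([0, 3, 6])       # two query groups: rows 0:3 and 3:6
    group_weight = np.array([1.0, 2.0])   # per-group weights; may be empty

    scores = np.empty(group_ptr.size - 1)
    for i in range(1, group_ptr.size):
        begin, end = group_ptr[i - 1], group_ptr[i]
        scores[i - 1] = roc_auc_score(y_true[begin:end], y_score[begin:end])

    # Unweighted mean when no group weights were supplied, weighted mean otherwise.
    print(scores.mean(), np.average(scores, weights=group_weight))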
30 changes: 28 additions & 2 deletions tests/python/test_with_sklearn.py
@@ -154,6 +154,32 @@ def test_ranking():
np.testing.assert_almost_equal(pred, pred_orig)


def test_ranking_metric() -> None:
from sklearn.metrics import roc_auc_score

X, y, qid, w = tm.make_ltr(512, 4, 3, 2)
    # Use AUC for the test, since ndcg_score in sklearn works only on label gain
    # instead of exponential gain.
    # Note that the AUC in sklearn differs from the one in XGBoost: sklearn's
    # compares the number of mis-classified documents, while XGBoost's compares
    # the number of mis-classified pairs.
ltr = xgb.XGBRanker(
eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2
)
ltr.fit(
X,
y,
qid=qid,
sample_weight=w,
eval_set=[(X, y)],
eval_qid=[qid],
sample_weight_eval_set=[w],
verbose=True,
)
results = ltr.evals_result()
assert results["validation_0"]["roc_auc_score"][-1] > 0.6


def test_stacking_regression():
from sklearn.datasets import load_diabetes
from sklearn.ensemble import RandomForestRegressor, StackingRegressor
@@ -1426,10 +1452,10 @@ def test_weighted_evaluation_metric():
X_train, X_test = X[:1600], X[1600:]
y_train, y_test = y[:1600], y[1600:]
weights_eval_set = np.random.choice([1, 2], len(X_test))

np.random.seed(0)
weights_train = np.random.choice([1, 2], len(X_train))

clf = xgb.XGBClassifier(
tree_method="hist",
eval_metric=log_loss,
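With this change XGBRanker accepts a callable eval_metric like the other sklearn estimators, so a per-query sklearn scorer can be monitored during fit. A usage sketch along the lines of the added test; the synthetic data below is illustrative (the test itself uses the project's tm.make_ltr helper), and labels are constructed so every query group contains both classes, since roc_auc_score requires that:

    import numpy as np
    import xgboost as xgb
    from sklearn.metrics import roc_auc_score

    rng = np.random.default_rng(0)
    n_groups, group_size = 32, 16
    n = n_groups * group_size
    X = rng.normal(size=(n, 4))
    y = np.tile(np.repeat([0, 1], group_size // 2), n_groups)  # both classes per group
    qid = np.repeat(np.arange(n_groups), group_size)           # contiguous query ids

    ranker = xgb.XGBRanker(
        eval_metric=roc_auc_score,  # callable metric, evaluated per query group
        n_estimators=10,
        tree_method="hist",
        max_depth=2,
    )
    ranker.fit(X, y, qid=qid, eval_set=[(X, y)], eval_qid=[qid], verbose=True)
    print(ranker.evals_result()["validation_0"]["roc_auc_score"][-1])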
