diff --git a/python-package/xgboost/sklearn.py b/python-package/xgboost/sklearn.py
index 946c87235b63..69bcac38d01a 100644
--- a/python-package/xgboost/sklearn.py
+++ b/python-package/xgboost/sklearn.py
@@ -4,6 +4,7 @@
 import json
 import os
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from typing import (
     Any,
     Callable,
@@ -127,6 +128,49 @@ def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
     return inner


+def ltr_metric_decorator(func: Callable, n_jobs: Optional[int]) -> Metric:
+    """Decorate a learning to rank metric."""
+
+    def inner(y_score: np.ndarray, dmatrix: DMatrix) -> Tuple[str, float]:
+        y_true = dmatrix.get_label()
+        group_ptr = dmatrix.get_uint_info("group_ptr")
+        if group_ptr.size < 2:
+            raise ValueError(
+                "Invalid `group_ptr`. Likely caused by invalid qid or group."
+            )
+        scores = np.empty(group_ptr.size - 1)
+        futures = []
+        weight = dmatrix.get_group()
+        no_weight = weight.size == 0
+
+        def task(i: int) -> float:
+            begin = group_ptr[i - 1]
+            end = group_ptr[i]
+            gy = y_true[begin:end]
+            gp = y_score[begin:end]
+            if gy.size == 1:
+                # Maybe there's a better default? 1.0 because many ranking score
+                # functions have output in range [0, 1].
+                return 1.0
+            return func(gy, gp)
+
+        workers = n_jobs if n_jobs is not None else os.cpu_count()
+        with ThreadPoolExecutor(max_workers=workers) as executor:
+            for i in range(1, group_ptr.size):
+                f = executor.submit(task, i)
+                futures.append(f)
+
+            for i, f in enumerate(futures):
+                scores[i] = f.result()
+
+        if no_weight:
+            return func.__name__, scores.mean()
+
+        return func.__name__, np.average(scores, weights=weight)
+
+    return inner
+
+
 __estimator_doc = """
     n_estimators : int
         Number of gradient boosted trees.  Equivalent to number of boosting
@@ -868,7 +912,10 @@ def _duplicated(parameter: str) -> None:
             metric = eval_metric
         elif callable(eval_metric):
             # Parameter from constructor or set_params
-            metric = _metric_decorator(eval_metric)
+            if self._get_type() == "ranker":
+                metric = ltr_metric_decorator(eval_metric, self.n_jobs)
+            else:
+                metric = _metric_decorator(eval_metric)
         else:
             params.update({"eval_metric": eval_metric})

@@ -1979,10 +2026,6 @@ def fit(
         ) = self._configure_fit(
             xgb_model, eval_metric, params, early_stopping_rounds, callbacks
         )
-        if callable(metric):
-            raise ValueError(
-                "Custom evaluation metric is not yet supported for XGBRanker."
-            )

         self._Booster = train(
             params,
diff --git a/tests/python/test_with_sklearn.py b/tests/python/test_with_sklearn.py
index 4f627cd34e85..55e14ae97e40 100644
--- a/tests/python/test_with_sklearn.py
+++ b/tests/python/test_with_sklearn.py
@@ -154,6 +154,32 @@ def test_ranking():
     np.testing.assert_almost_equal(pred, pred_orig)


+def test_ranking_metric() -> None:
+    from sklearn.metrics import roc_auc_score
+
+    X, y, qid, w = tm.make_ltr(512, 4, 3, 2)
+    # use auc for test as ndcg_score in sklearn works only on label gain instead of exp
+    # gain.
+    # note that the auc in sklearn is different from the one in XGBoost. The one in
+    # sklearn compares the number of mis-classified docs, while the one in xgboost
+    # compares the number of mis-classified pairs.
+    ltr = xgb.XGBRanker(
+        eval_metric=roc_auc_score, n_estimators=10, tree_method="hist", max_depth=2
+    )
+    ltr.fit(
+        X,
+        y,
+        qid=qid,
+        sample_weight=w,
+        eval_set=[(X, y)],
+        eval_qid=[qid],
+        sample_weight_eval_set=[w],
+        verbose=True,
+    )
+    results = ltr.evals_result()
+    assert results["validation_0"]["roc_auc_score"][-1] > 0.6
+
+
 def test_stacking_regression():
     from sklearn.datasets import load_diabetes
     from sklearn.ensemble import RandomForestRegressor, StackingRegressor
@@ -1426,10 +1452,10 @@ def test_weighted_evaluation_metric():
     X_train, X_test = X[:1600], X[1600:]
     y_train, y_test = y[:1600], y[1600:]
     weights_eval_set = np.random.choice([1, 2], len(X_test))
-
+
     np.random.seed(0)
     weights_train = np.random.choice([1, 2], len(X_train))
-
+
     clf = xgb.XGBClassifier(
         tree_method="hist",
         eval_metric=log_loss,
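
Note (not part of the patch): a minimal standalone sketch of how the per-query aggregation introduced above behaves. The data below is hypothetical example input; the loop mirrors the slicing in ltr_metric_decorator (group_ptr marks query boundaries, the user metric runs once per query, results are averaged) but drops the ThreadPoolExecutor and group weights for brevity.

# Illustration only (hypothetical data): per-query metric aggregation in the
# spirit of ltr_metric_decorator, without threading or group weights.
import numpy as np
from sklearn.metrics import roc_auc_score

y_true = np.array([1, 0, 1, 0, 0, 1])
y_score = np.array([0.9, 0.2, 0.4, 0.8, 0.1, 0.7])
group_ptr = np.array([0, 3, 6])  # two queries: docs [0, 3) and [3, 6)

scores = []
for i in range(1, group_ptr.size):
    begin, end = group_ptr[i - 1], group_ptr[i]
    gy, gp = y_true[begin:end], y_score[begin:end]
    # The patch returns 1.0 for single-document groups; both groups here have
    # more than one document, so the metric is applied directly.
    scores.append(roc_auc_score(gy, gp))

print(np.mean(scores))  # unweighted mean over queries, as when get_group() is empty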