Support sklearn cross validation for ranker. (#8859)
* Support sklearn cross validation for ranker.

- Add a convention for X to include a special `qid` column.

sklearn utilities consider only `X`, `y` and `sample_weight` for supervised learning
algorithms, but we need an additional qid array for ranking.

It's important to support sklearn's cross-validation function, since other
tuning utilities like grid search are built on top of cross validation.
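
As a minimal sketch of the workflow this enables (assuming scikit-learn and pandas are installed; the data, group sizes, and parameters are made up for illustration), `cross_val_score` ends up calling the new `XGBRanker.score` on each validation fold, which is what the `qid` column convention makes possible:

import numpy as np
import pandas as pd
from sklearn.model_selection import GroupKFold, cross_val_score

import xgboost as xgb

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(32, 2)), columns=["f0", "f1"])
X["qid"] = np.repeat(np.arange(8), 4)  # the special column consumed by XGBRanker
y = rng.integers(0, 4, size=32)  # relevance labels

ranker = xgb.XGBRanker(n_estimators=4)
# sklearn only passes (X, y) around; the ranker strips `qid` back out of X.
scores = cross_val_score(ranker, X, y, cv=GroupKFold(n_splits=2), groups=X["qid"])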
trivialfis authored Mar 6, 2023
1 parent cad7401 commit 7eba285
Showing 8 changed files with 232 additions and 43 deletions.
14 changes: 8 additions & 6 deletions python-package/xgboost/callback.py
@@ -23,7 +23,13 @@
import numpy

from . import collective
-from .core import Booster, DMatrix, XGBoostError, _get_booster_layer_trees
+from .core import (
+    Booster,
+    DMatrix,
+    XGBoostError,
+    _get_booster_layer_trees,
+    _parse_eval_str,
+)

__all__ = [
"TrainingCallback",
@@ -250,11 +256,7 @@ def after_iteration(
for _, name in evals:
assert name.find("-") == -1, "Dataset name should not contain `-`"
score: str = model.eval_set(evals, epoch, self.metric, self._output_margin)
-splited = score.split()[1:]  # into datasets
-# split up `test-error:0.1234`
-metric_score_str = [tuple(s.split(":")) for s in splited]
-# convert to float
-metric_score = [(n, float(s)) for n, s in metric_score_str]
+metric_score = _parse_eval_str(score)
self._update_history(metric_score, epoch)
ret = any(c.after_iteration(model, epoch, self.history) for c in self.callbacks)
return ret
2 changes: 1 addition & 1 deletion python-package/xgboost/collective.py
@@ -231,7 +231,7 @@ def allreduce(data: np.ndarray, op: Op) -> np.ndarray: # pylint:disable=invalid
if buf.base is data.base:
buf = buf.copy()
if buf.dtype not in DTYPE_ENUM__:
raise Exception(f"data type {buf.dtype} not supported")
raise TypeError(f"data type {buf.dtype} not supported")
_check_call(
_LIB.XGCommunicatorAllreduce(
buf.ctypes.data_as(ctypes.c_void_p),
10 changes: 10 additions & 0 deletions python-package/xgboost/core.py
@@ -111,6 +111,16 @@ def make_jcargs(**kwargs: Any) -> bytes:
return from_pystr_to_cstr(json.dumps(kwargs))


def _parse_eval_str(result: str) -> List[Tuple[str, float]]:
"""Parse an eval result string from the booster."""
splited = result.split()[1:]
# split up `test-error:0.1234`
metric_score_str = [tuple(s.split(":")) for s in splited]
# convert to float
metric_score = [(n, float(s)) for n, s in metric_score_str]
return metric_score


IterRange = TypeVar("IterRange", Optional[Tuple[int, int]], Tuple[int, int])


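For orientation, a small worked example of what the new helper returns. The eval string below is a made-up instance of the booster's `[iteration] name-metric:value` output format, and `_parse_eval_str` is a private helper, so this is illustration only:

from xgboost.core import _parse_eval_str

# split()[1:] drops the leading "[0]" iteration token; each remaining
# "name-metric:value" chunk becomes a (name, float(value)) tuple.
parsed = _parse_eval_str("[0]\ttrain-ndcg:0.92\teval-ndcg:0.87")
assert parsed == [("train-ndcg", 0.92), ("eval-ndcg", 0.87)]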
2 changes: 1 addition & 1 deletion python-package/xgboost/rabit.py
@@ -136,7 +136,7 @@ def allreduce( # pylint:disable=invalid-name
"""
if prepare_fun is None:
return collective.allreduce(data, collective.Op(op))
raise Exception("preprocessing function is no longer supported")
raise ValueError("preprocessing function is no longer supported")


def version_number() -> int:
159 changes: 124 additions & 35 deletions python-package/xgboost/sklearn.py
@@ -43,8 +43,9 @@
XGBoostError,
_convert_ntree_limit,
_deprecate_positional_args,
_parse_eval_str,
)
-from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array
+from .data import _is_cudf_df, _is_cudf_ser, _is_cupy_array, _is_pandas_df
from .training import train


@@ -1812,32 +1813,43 @@ def fit(
return self


def _get_qid(
X: ArrayLike, qid: Optional[ArrayLike]
) -> Tuple[ArrayLike, Optional[ArrayLike]]:
"""Get the special qid column from X if exists."""
if (_is_pandas_df(X) or _is_cudf_df(X)) and hasattr(X, "qid"):
if qid is not None:
raise ValueError(
"Found both the special column `qid` in `X` and the `qid` from the"
"`fit` method. Please remove one of them."
)
q_x = X.qid
X = X.drop("qid", axis=1)
return X, q_x
return X, qid


@xgboost_model_doc(
"Implementation of the Scikit-Learn API for XGBoost Ranking.",
"""Implementation of the Scikit-Learn API for XGBoost Ranking.""",
["estimators", "model"],
end_note="""
.. note::
The default objective for XGBRanker is "rank:pairwise"
.. note::
A custom objective function is currently not supported by XGBRanker.
Likewise, a custom metric function is not supported either.
.. note::
-Query group information is required for ranking tasks by either using the
-`group` parameter or `qid` parameter in `fit` method. This information is
-not required in 'predict' method and multiple groups can be predicted on
-a single call to `predict`.
+Query group information is only required for ranking training, not for
+prediction. Multiple groups can be predicted on a single call to
+:py:meth:`predict`.
When fitting the model with the `group` parameter, your data need to be sorted
-by query group first. `group` must be an array that contains the size of each
+by the query group first. `group` is an array that contains the size of each
query group.
-When fitting the model with the `qid` parameter, your data does not need
-sorting. `qid` must be an array that contains the group of each training
-sample.
+Similarly, when fitting the model with the `qid` parameter, the data should be
+sorted according to query index and `qid` is an array that contains the query
+index for each training sample.
For example, if your original data look like:
@@ -1859,9 +1871,10 @@ def fit(
| 2 | 1 | x_7 |
+-------+-----------+---------------+
-then `fit` method can be called with either `group` array as ``[3, 4]``
-or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is the qid column.
-""",
+then :py:meth:`fit` method can be called with either `group` array as ``[3, 4]``
+or with `qid` as ``[1, 1, 1, 2, 2, 2, 2]``, that is the qid column. Also, the
+`qid` can be a special column of input `X` instead of a separate parameter, see
+:py:meth:`fit` for more info.""",
)
class XGBRanker(XGBModel, XGBRankerMixIn):
# pylint: disable=missing-docstring,too-many-arguments,invalid-name
@@ -1873,6 +1886,16 @@ def __init__(self, *, objective: str = "rank:pairwise", **kwargs: Any):
if "rank:" not in objective:
raise ValueError("please use XGBRanker for ranking task")

def _create_ltr_dmatrix(
self, ref: Optional[DMatrix], data: ArrayLike, qid: ArrayLike, **kwargs: Any
) -> DMatrix:
data, qid = _get_qid(data, qid)

if kwargs.get("group", None) is None and qid is None:
raise ValueError("Either `group` or `qid` is required for ranking task")

return super()._create_dmatrix(ref=ref, data=data, qid=qid, **kwargs)

@_deprecate_positional_args
def fit(
self,
@@ -1907,6 +1930,23 @@ def fit(
X :
Feature matrix. See :ref:`py-data` for a list of supported types.
When this is a :py:class:`pandas.DataFrame` or a :py:class:`cudf.DataFrame`,
it may contain a special column called ``qid`` for specifying the query
index. Using a special column is the same as using the `qid` parameter,
except for being compatible with sklearn utility functions like
:py:func:`sklearn.model_selection.cross_validate`. The same convention
applies to :py:meth:`XGBRanker.score` and :py:meth:`XGBRanker.predict`.
+-----+----------------+----------------+
| qid | feat_0 | feat_1 |
+-----+----------------+----------------+
| 0 | :math:`x_{00}` | :math:`x_{01}` |
+-----+----------------+----------------+
| 1 | :math:`x_{10}` | :math:`x_{11}` |
+-----+----------------+----------------+
| 1 | :math:`x_{20}` | :math:`x_{21}` |
+-----+----------------+----------------+
When the ``tree_method`` is set to ``hist`` or ``gpu_hist``, internally, the
:py:class:`QuantileDMatrix` will be used instead of the :py:class:`DMatrix`
for conserving memory. However, this has performance implications when the
@@ -1916,21 +1956,22 @@
y :
Labels
group :
-Size of each query group of training data. Should have as many elements as the
-query groups in the training data. If this is set to None, then user must
-provide qid.
+Size of each query group of training data. Should have as many elements as
+the query groups in the training data. If this is set to None, then user
+must provide qid.
qid :
Query ID for each training sample. Should have the size of n_samples. If
-this is set to None, then user must provide group.
+this is set to None, then user must provide group or a special column in X.
sample_weight :
Query group weights
.. note:: Weights are per-group for ranking tasks
In ranking task, one weight is assigned to each query group/id (not each
data point). This is because we only care about the relative ordering of
-data points within each group, so it doesn't make sense to assign weights
-to individual data points.
+data points within each group, so it doesn't make sense to assign
+weights to individual data points.
base_margin :
Global bias for each instance.
eval_set :
Expand All @@ -1942,7 +1983,8 @@ def fit(
query groups in the ``i``-th pair in **eval_set**.
eval_qid :
A list in which ``eval_qid[i]`` is the array containing query ID of ``i``-th
-pair in **eval_set**.
+pair in **eval_set**. The special column convention in `X` applies to
+validation datasets as well.
eval_metric : str, list of str, optional
.. deprecated:: 1.6.0
@@ -1985,16 +2027,7 @@
Use `callbacks` in :py:meth:`__init__` or :py:meth:`set_params` instead.
"""
-# check if group information is provided
with config_context(verbosity=self.verbosity):
-if group is None and qid is None:
-raise ValueError("group or qid is required for ranking task")

-if eval_set is not None:
-if eval_group is None and eval_qid is None:
-raise ValueError(
-"eval_group or eval_qid is required if eval_set is not None"
-)
train_dmatrix, evals = _wrap_evaluation_matrices(
missing=self.missing,
X=X,
Expand All @@ -2009,7 +2042,7 @@ def fit(
base_margin_eval_set=base_margin_eval_set,
eval_group=eval_group,
eval_qid=eval_qid,
-create_dmatrix=self._create_dmatrix,
+create_dmatrix=self._create_ltr_dmatrix,
enable_categorical=self.enable_categorical,
feature_types=self.feature_types,
)
@@ -2044,3 +2077,59 @@

self._set_evaluation_result(evals_result)
return self

def predict(
self,
X: ArrayLike,
output_margin: bool = False,
ntree_limit: Optional[int] = None,
validate_features: bool = True,
base_margin: Optional[ArrayLike] = None,
iteration_range: Optional[Tuple[int, int]] = None,
) -> ArrayLike:
X, _ = _get_qid(X, None)
return super().predict(
X,
output_margin,
ntree_limit,
validate_features,
base_margin,
iteration_range,
)

def apply(
self,
X: ArrayLike,
ntree_limit: int = 0,
iteration_range: Optional[Tuple[int, int]] = None,
) -> ArrayLike:
X, _ = _get_qid(X, None)
return super().apply(X, ntree_limit, iteration_range)

def score(self, X: ArrayLike, y: ArrayLike) -> float:
"""Evaluate score for data using the last evaluation metric.
Parameters
----------
X : pd.DataFrame|cudf.DataFrame
Feature matrix. A DataFrame with a special `qid` column.
y :
Labels
Returns
-------
score :
The result of the last evaluation metric for the ranker.
"""
X, qid = _get_qid(X, None)
Xyq = DMatrix(X, y, qid=qid)
if callable(self.eval_metric):
metric = ltr_metric_decorator(self.eval_metric, self.n_jobs)
result_str = self.get_booster().eval_set([(Xyq, "eval")], feval=metric)
else:
result_str = self.get_booster().eval(Xyq)

metric_score = _parse_eval_str(result_str)
return metric_score[-1][1]
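
Pulling the pieces of this file together, a hedged sketch of the three equivalent ways of supplying group information described in the updated docstring, using the toy two-group data from its table (values invented):

import pandas as pd

import xgboost as xgb

X = pd.DataFrame({"feat": [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]})
y = [0, 1, 2, 0, 1, 2, 3]  # one relevance label per row

# 1) explicit group sizes, 2) explicit per-row qid, 3) qid packed into X.
xgb.XGBRanker(n_estimators=2).fit(X, y, group=[3, 4])
xgb.XGBRanker(n_estimators=2).fit(X, y, qid=[1, 1, 1, 2, 2, 2, 2])

Xq = X.assign(qid=[1, 1, 1, 2, 2, 2, 2])
ranker = xgb.XGBRanker(n_estimators=2).fit(Xq, y)
ranker.score(Xq, y)  # pops the qid column, then reports the last eval metric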
72 changes: 72 additions & 0 deletions python-package/xgboost/testing/ranking.py
@@ -0,0 +1,72 @@
# pylint: disable=too-many-locals
"""Tests for learning to rank."""
from types import ModuleType
from typing import Any

import numpy as np
import pytest

import xgboost as xgb
from xgboost import testing as tm


def run_ranking_qid_df(impl: ModuleType, tree_method: str) -> None:
"""Test ranking with qid packed into X."""
import scipy.sparse
from sklearn.metrics import mean_squared_error
from sklearn.model_selection import StratifiedGroupKFold, cross_val_score

X, y, q, _ = tm.make_ltr(n_samples=128, n_features=2, n_query_groups=8, max_rel=3)

# pack qid into x using dataframe
df = impl.DataFrame(X)
df["qid"] = q
ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method)
ranker.fit(df, y)
s = ranker.score(df, y)
assert s > 0.7

# works with validation datasets as well
valid_df = df.copy()
valid_df.iloc[0, 0] = 3.0
ranker.fit(df, y, eval_set=[(valid_df, y)])

# same as passing qid directly
ranker = xgb.XGBRanker(n_estimators=3, eval_metric="ndcg", tree_method=tree_method)
ranker.fit(X, y, qid=q)
s1 = ranker.score(df, y)
assert np.isclose(s, s1)

# Works with standard sklearn cv
if tree_method != "gpu_hist":
# we need cuML for this.
kfold = StratifiedGroupKFold(shuffle=False)
results = cross_val_score(ranker, df, y, cv=kfold, groups=df.qid)
assert len(results) == 5

# Works with custom metric
def neg_mse(*args: Any, **kwargs: Any) -> float:
return -float(mean_squared_error(*args, **kwargs))

ranker = xgb.XGBRanker(n_estimators=3, eval_metric=neg_mse, tree_method=tree_method)
ranker.fit(df, y, eval_set=[(valid_df, y)])
score = ranker.score(valid_df, y)
assert np.isclose(score, ranker.evals_result()["validation_0"]["neg_mse"][-1])

# Works with sparse data
if tree_method != "gpu_hist":
# no sparse with cuDF
X_csr = scipy.sparse.csr_matrix(X)
df = impl.DataFrame.sparse.from_spmatrix(
X_csr, columns=[str(i) for i in range(X.shape[1])]
)
df["qid"] = q
ranker = xgb.XGBRanker(
n_estimators=3, eval_metric="ndcg", tree_method=tree_method
)
ranker.fit(df, y)
s2 = ranker.score(df, y)
assert np.isclose(s2, s)

with pytest.raises(ValueError, match="Either `group` or `qid`."):
ranker.fit(df, y, eval_set=[(X, y)])
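
The helper takes the DataFrame module as a parameter so the same body can run against pandas on CPU and cuDF on GPU. A hypothetical pytest wrapper (the helper's module path is from this commit; the test name and parametrisation are assumptions) might look like:

import pandas as pd
import pytest

from xgboost.testing.ranking import run_ranking_qid_df


@pytest.mark.parametrize("tree_method", ["hist", "approx"])
def test_ranking_qid_df(tree_method: str) -> None:
    run_ranking_qid_df(pd, tree_method)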