Skip to content

Commit

Permalink
add CatBoostRanker support (mlflow#6032)
Browse files Browse the repository at this point in the history
* Add CatBoostRanker support

Signed-off-by: Daniil Gafni <[email protected]>

* sorted group_id in tests

Signed-off-by: Daniil Gafni <[email protected]>

* fixed typo

Signed-off-by: Daniil Gafni <[email protected]>

* fixed typo

Signed-off-by: Daniil Gafni <[email protected]>

* add separate catboost_ranker test

Signed-off-by: Daniil Gafni <[email protected]>

* fix some issues

Signed-off-by: Daniil Gafni <[email protected]>

* Autoformat: https://github.com/mlflow/mlflow/actions/runs/2454908292

Signed-off-by: mlflow-automation <[email protected]>

* maybe import CatBoostRanker

Signed-off-by: Daniil Gafni <[email protected]>

* skip -> skipif

Signed-off-by: Daniil Gafni <[email protected]>

* use __name__ attribute

Signed-off-by: Daniil Gafni <[email protected]>

* packaging.version

Signed-off-by: Daniil Gafni <[email protected]>

* fix lint issues

Signed-off-by: harupy <[email protected]>

* fix test

Signed-off-by: harupy <[email protected]>

* simplify test

Signed-off-by: harupy <[email protected]>

Co-authored-by: mlflow-automation <[email protected]>
Co-authored-by: harupy <[email protected]>
Signed-off-by: Michal Karzynski <[email protected]>
  • Loading branch information
3 people authored and postrational committed Jul 27, 2022
1 parent ce88320 commit 59b6a82
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 3 deletions.
15 changes: 12 additions & 3 deletions mlflow/catboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
https://catboost.ai/docs/concepts/python-reference_catboost_save_model.html
.. _CatBoostClassifier:
https://catboost.ai/docs/concepts/python-reference_catboostclassifier.html
.. _CatBoostRanker:
https://catboost.ai/docs/concepts/python-reference_catboostranker.html
.. _CatBoostRegressor:
https://catboost.ai/docs/concepts/python-reference_catboostregressor.html
"""
Expand Down Expand Up @@ -90,7 +92,7 @@ def save_model(
Save a CatBoost model to a path on the local file system.
:param cb_model: CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_,
or `CatBoostRegressor`_) to be saved.
`CatBoostRanker`_, or `CatBoostRegressor`_) to be saved.
:param path: Local path where the model is to be saved.
:param conda_env: {{ conda_env }}
:param code_paths: A list of local filesystem paths to Python file dependencies (or directories
Expand Down Expand Up @@ -209,7 +211,7 @@ def log_model(
Log a CatBoost model as an MLflow artifact for the current run.
:param cb_model: CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_,
or `CatBoostRegressor`_) to be saved.
`CatBoostRanker`_, or `CatBoostRegressor`_) to be saved.
:param artifact_path: Run-relative artifact path.
:param conda_env: {{ conda_env }}
:param code_paths: A list of local filesystem paths to Python file dependencies (or directories
Expand Down Expand Up @@ -269,6 +271,13 @@ def _init_model(model_type):

model_types = {c.__name__: c for c in [CatBoost, CatBoostClassifier, CatBoostRegressor]}

try:
from catboost import CatBoostRanker

model_types[CatBoostRanker.__name__] = CatBoostRanker
except ImportError:
pass

if model_type not in model_types:
raise TypeError(
"Invalid model type: '{}'. Must be one of {}".format(
Expand Down Expand Up @@ -317,7 +326,7 @@ def load_model(model_uri, dst_path=None):
This directory must already exist. If unspecified, a local output
path will be created.
:return: A CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_,
:return: A CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_, `CatBoostRanker`_,
or `CatBoostRegressor`_)
"""
local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path)
Expand Down
27 changes: 27 additions & 0 deletions tests/catboost/test_catboost_model_export.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from collections import namedtuple
from unittest import mock
from packaging.version import Version
import os
import pytest
import yaml
Expand Down Expand Up @@ -91,6 +92,32 @@ def test_init_model(model_type):
assert model.__class__.__name__ == model_type


@pytest.mark.skipif(
    Version(cb.__version__) < Version("0.26.0"),
    reason="catboost < 0.26.0 does not support CatBoostRanker",
)
def test_log_catboost_ranker():
    """
    Log a fitted CatBoostRanker with MLflow, load it back, and check the round trip.

    The ranker gets its own test because fitting it needs a ``group_id`` argument,
    so the setup differs from the other model types exercised in this module.
    """
    X, y = get_iris()

    # Ranking requires a group id per row. These ids are dummies with no real meaning
    # for the Iris data; they only need to let the code path run. Sorting keeps rows
    # that share a group id next to each other.
    dummy_group_id = np.sort(np.arange(len(X)) % 3)

    ranker = cb.CatBoostRanker(**MODEL_PARAMS, subsample=1.0)
    ranker.fit(X, y, group_id=dummy_group_id)

    with mlflow.start_run():
        model_info = mlflow.catboost.log_model(ranker, "model")
        loaded_model = mlflow.catboost.load_model(model_info.model_uri)

        # The reloaded object must be the same model class and predict the same scores.
        assert isinstance(loaded_model, cb.CatBoostRanker)
        expected_scores = ranker.predict(X)
        reloaded_scores = loaded_model.predict(X)
        np.testing.assert_array_almost_equal(expected_scores, reloaded_scores)


def test_init_model_throws_for_invalid_model_type():
with pytest.raises(TypeError, match="Invalid model type"):
mlflow.catboost._init_model("unsupported")
Expand Down

0 comments on commit 59b6a82

Please sign in to comment.