diff --git a/mlflow/catboost.py b/mlflow/catboost.py index a3af6e1fd3634..cb9eb50b7d4ff 100644 --- a/mlflow/catboost.py +++ b/mlflow/catboost.py @@ -13,6 +13,8 @@ https://catboost.ai/docs/concepts/python-reference_catboost_save_model.html .. _CatBoostClassifier: https://catboost.ai/docs/concepts/python-reference_catboostclassifier.html +.. _CatBoostRanker: + https://catboost.ai/docs/concepts/python-reference_catboostranker.html .. _CatBoostRegressor: https://catboost.ai/docs/concepts/python-reference_catboostregressor.html """ @@ -90,7 +92,7 @@ def save_model( Save a CatBoost model to a path on the local file system. :param cb_model: CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_, - or `CatBoostRegressor`_) to be saved. + `CatBoostRanker`_, or `CatBoostRegressor`_) to be saved. :param path: Local path where the model is to be saved. :param conda_env: {{ conda_env }} :param code_paths: A list of local filesystem paths to Python file dependencies (or directories @@ -209,7 +211,7 @@ def log_model( Log a CatBoost model as an MLflow artifact for the current run. :param cb_model: CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_, - or `CatBoostRegressor`_) to be saved. + `CatBoostRanker`_, or `CatBoostRegressor`_) to be saved. :param artifact_path: Run-relative artifact path. :param conda_env: {{ conda_env }} :param code_paths: A list of local filesystem paths to Python file dependencies (or directories @@ -269,6 +271,13 @@ def _init_model(model_type): model_types = {c.__name__: c for c in [CatBoost, CatBoostClassifier, CatBoostRegressor]} + try: + from catboost import CatBoostRanker + + model_types[CatBoostRanker.__name__] = CatBoostRanker + except ImportError: + pass + if model_type not in model_types: raise TypeError( "Invalid model type: '{}'. Must be one of {}".format( @@ -317,7 +326,7 @@ def load_model(model_uri, dst_path=None): This directory must already exist. If unspecified, a local output path will be created. - :return: A CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_, + :return: A CatBoost model (an instance of `CatBoost`_, `CatBoostClassifier`_, `CatBoostRanker`_, or `CatBoostRegressor`_) """ local_model_path = _download_artifact_from_uri(artifact_uri=model_uri, output_path=dst_path) diff --git a/tests/catboost/test_catboost_model_export.py b/tests/catboost/test_catboost_model_export.py index 87e29389a0212..2c6381442240d 100644 --- a/tests/catboost/test_catboost_model_export.py +++ b/tests/catboost/test_catboost_model_export.py @@ -1,5 +1,6 @@ from collections import namedtuple from unittest import mock +from packaging.version import Version import os import pytest import yaml @@ -91,6 +92,32 @@ def test_init_model(model_type): assert model.__class__.__name__ == model_type +@pytest.mark.skipif( + Version(cb.__version__) < Version("0.26.0"), + reason="catboost < 0.26.0 does not support CatBoostRanker", +) +def test_log_catboost_ranker(): + """ + This is a separate test for the CatBoostRanker model. + It is separate since the ranking task requires a group_id column which makes the code different. + """ + # the ranking task requires setting a group_id + # we are creating a dummy group_id here that doesn't make any sense for the Iris dataset, + # but is ok for testing if the code is running correctly + X, y = get_iris() + dummy_group_id = np.arange(len(X)) % 3 + dummy_group_id.sort() + + model = cb.CatBoostRanker(**MODEL_PARAMS, subsample=1.0) + model.fit(X, y, group_id=dummy_group_id) + + with mlflow.start_run(): + model_info = mlflow.catboost.log_model(model, "model") + loaded_model = mlflow.catboost.load_model(model_info.model_uri) + assert isinstance(loaded_model, cb.CatBoostRanker) + np.testing.assert_array_almost_equal(model.predict(X), loaded_model.predict(X)) + + def test_init_model_throws_for_invalid_model_type(): with pytest.raises(TypeError, match="Invalid model type"): mlflow.catboost._init_model("unsupported")