Add autologging for scikit-learn (mlflow#3287)
* Add autologging for scikit-learn

Signed-off-by: harupy <[email protected]>

* Update sklearn's version

Signed-off-by: harupy <[email protected]>

* Remove unrelated file

Signed-off-by: harupy <[email protected]>

* rename

Signed-off-by: harupy <[email protected]>

* rename load_model

Signed-off-by: harupy <[email protected]>

* Remove blank line

Signed-off-by: harupy <[email protected]>

* Reorder imports

Signed-off-by: harupy <[email protected]>

* fix

Signed-off-by: harupy <[email protected]>

* DRY

Signed-off-by: harupy <[email protected]>

* Emit warning on older versions of sklearn

Signed-off-by: harupy <[email protected]>

* Use warnings.warn

Signed-off-by: harupy <[email protected]>

* Remove unused argument

Signed-off-by: harupy <[email protected]>

* Revert changes on requirements

Signed-off-by: harupy <[email protected]>

* Use LooseVersion

Signed-off-by: harupy <[email protected]>

* Fix _get_all_estimators

Signed-off-by: harupy <[email protected]>

* rename vars

Signed-off-by: harupy <[email protected]>

* Use backported all_estimators

Signed-off-by: harupy <[email protected]>

* Fix lint errors

Signed-off-by: harupy <[email protected]>

* Remove print

Signed-off-by: harupy <[email protected]>

* simplify code

Signed-off-by: harupy <[email protected]>

* Create sklearn directory

Signed-off-by: harupy <[email protected]>

* Move _all_estimators to utils

Signed-off-by: harupy <[email protected]>

* Remove link

Signed-off-by: harupy <[email protected]>

* fix

Signed-off-by: harupy <[email protected]>

* Fix active_run_exists' condition

Signed-off-by: harupy <[email protected]>

* Verify no children

Signed-off-by: harupy <[email protected]>

* Add experiment_id

Signed-off-by: harupy <[email protected]>

* Specify stacklevel

Signed-off-by: harupy <[email protected]>

* Wrap fit with try-except

Signed-off-by: harupy <[email protected]>

* Remove use_caplog

Signed-off-by: harupy <[email protected]>

* rename test

Signed-off-by: harupy <[email protected]>

* Remove temp_tracking_uri

Signed-off-by: harupy <[email protected]>

* Remove unused imports

Signed-off-by: harupy <[email protected]>

* Add docstring for _all_estimators

Signed-off-by: harupy <[email protected]>

* Fix assertions

Signed-off-by: harupy <[email protected]>

* Wrap score with try-except

Signed-off-by: harupy <[email protected]>

* fix

Signed-off-by: harupy <[email protected]>

* Add log assertion to test_autolog_marks_run_as_failed_when_fit_fails

Signed-off-by: harupy <[email protected]>

* indent

Signed-off-by: harupy <[email protected]>

* simplify code

Signed-off-by: harupy <[email protected]>

* Add failure reasons

Signed-off-by: harupy <[email protected]>

* Fix assertion order

Signed-off-by: harupy <[email protected]>

* Assert metrics is empty when score fails

Signed-off-by: harupy <[email protected]>

* minor fix

Signed-off-by: harupy <[email protected]>

* Assert after with

Signed-off-by: harupy <[email protected]>

* Use readable class name

Signed-off-by: harupy <[email protected]>

* Throw when fit fails

Signed-off-by: harupy <[email protected]>

* Rename active_run_exists to should_start_run

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* fix if condition

Signed-off-by: harupy <[email protected]>

* pass sample_weight if both fit and score have it

Signed-off-by: harupy <[email protected]>

* Exclude property methods from patching

Signed-off-by: harupy <[email protected]>

* Chunk params to avoid hitting log_batch API limit

Signed-off-by: harupy <[email protected]>

* Fix args handling

Signed-off-by: harupy <[email protected]>

* Use all_estimators if sklearn.utils.all_estimators exists

Signed-off-by: harupy <[email protected]>

* Fix lint errors

Signed-off-by: harupy <[email protected]>

* Remove useless ()

Signed-off-by: harupy <[email protected]>

* Fix test name

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* Use model.fit if not parametrized

Signed-off-by: harupy <[email protected]>

* Temporarily add sklearn job

Signed-off-by: harupy <[email protected]>

* Add pytest

Signed-off-by: harupy <[email protected]>

* rerun

Signed-off-by: harupy <[email protected]>

* fix config

Signed-off-by: harupy <[email protected]>

* Fix install

Signed-off-by: harupy <[email protected]>

* do not run install-common.sh

Signed-off-by: harupy <[email protected]>

* Disable fail-fast

Signed-off-by: harupy <[email protected]>

* Remove sklearn.datasets

Signed-off-by: harupy <[email protected]>

* Add 0.22.2

Signed-off-by: harupy <[email protected]>

* Try print_changed_only=True

Signed-off-by: harupy <[email protected]>

* Truncate dict value

Signed-off-by: harupy <[email protected]>

* Add test for value truncation

Signed-off-by: harupy <[email protected]>

* De-hardcode tests

Signed-off-by: harupy <[email protected]>

* Remove set_config

Signed-off-by: harupy <[email protected]>

* Use try_mlflow_log for mlflow.end_run

Signed-off-by: harupy <[email protected]>

* Use try-catch for _all_estimators

Signed-off-by: harupy <[email protected]>

* Fix warning message for scoring error

Signed-off-by: harupy <[email protected]>

* Mark autolog as experimental

Signed-off-by: harupy <[email protected]>

* Mark tests for autolog as large

Signed-off-by: harupy <[email protected]>

* Add test_fit_takes_Xy_as_keyword_arguments

Signed-off-by: harupy <[email protected]>

* Emit warning message when truncating key or value

Signed-off-by: harupy <[email protected]>

* De-hardcode x and y names

Signed-off-by: harupy <[email protected]>

* Fix mangled-signatures link

Signed-off-by: harupy <[email protected]>

* Add comments

Signed-off-by: harupy <[email protected]>

* Add doc for _get_args_for_score

Signed-off-by: harupy <[email protected]>

* Add large option

Signed-off-by: harupy <[email protected]>

* Apply truncation to expected dict

Signed-off-by: harupy <[email protected]>

* DRY

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* De-hardcode sklearn version

Signed-off-by: harupy <[email protected]>

* Fix lint

Signed-off-by: harupy <[email protected]>

* De-hardcode model dir

Signed-off-by: harupy <[email protected]>

* Fix patch target

Signed-off-by: harupy <[email protected]>

* Use called_once_with

Signed-off-by: harupy <[email protected]>

* Add a new test case to test_both_fit_and_score_contain_sample_weight

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* Remove unused function

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* DRY

Signed-off-by: harupy <[email protected]>

* Fix func order

Signed-off-by: harupy <[email protected]>

* Fix lint

Signed-off-by: harupy <[email protected]>

* Capitalize x

Signed-off-by: harupy <[email protected]>

* Override unbound methods

Signed-off-by: harupy <[email protected]>

* Introduce throw_if_try_mlflow_log_has_emitted_warnings fixture

Signed-off-by: harupy <[email protected]>

* Fix test_fit_takes_Xy_as_keyword_arguments

Signed-off-by: harupy <[email protected]>

* Add assertions for logged data

Signed-off-by: harupy <[email protected]>

* Add assert_called_once_with to test_call_fit_with_arguments_score_does_not_accept

Signed-off-by: harupy <[email protected]>

* Split Xy

Signed-off-by: harupy <[email protected]>

* Move log_metric to else clause

Signed-off-by: harupy <[email protected]>

* Fix lint errors

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* Add docstring for autolog

Signed-off-by: harupy <[email protected]>

* Move pylint disable comments

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* Fix fixture for try_mlflow_log

Signed-off-by: harupy <[email protected]>

* Add test_autolog_does_not_throw_when_mlflow_logging_fails

Signed-off-by: harupy <[email protected]>

* Fix lint

Signed-off-by: harupy <[email protected]>

* Replace key_is_none with val_is_none

Signed-off-by: harupy <[email protected]>

* Use MAX_ENTITY_KEY_LENGTH

Signed-off-by: harupy <[email protected]>

* Add comment for _MIN_SKLEARN_VERSION

Signed-off-by: harupy <[email protected]>

* Enhance comment for prop methods exclusion

Signed-off-by: harupy <[email protected]>

* Add todo for wrap & patch

Signed-off-by: harupy <[email protected]>

* bump sklearn

Signed-off-by: harupy <[email protected]>

* Update action config

Signed-off-by: harupy <[email protected]>

* Rename is_old_version

Signed-off-by: harupy <[email protected]>

* test _is_supported_version

Signed-off-by: harupy <[email protected]>

* Fix command

Signed-off-by: harupy <[email protected]>

* Fix _is_supported_version

Signed-off-by: harupy <[email protected]>

* Add continue-on-error

Signed-off-by: harupy <[email protected]>

* Emit a warning if tests fail on an unsupported version

Signed-off-by: harupy <[email protected]>

* Add warning step

Signed-off-by: harupy <[email protected]>

* Fix workflow

Signed-off-by: harupy <[email protected]>

* Use set +e

Signed-off-by: harupy <[email protected]>

* debug

Signed-off-by: harupy <[email protected]>

* Fix condition

Signed-off-by: harupy <[email protected]>

* Simplify workflow

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* Add comment on why include unsupported version

Signed-off-by: harupy <[email protected]>

* Update doc

Signed-off-by: harupy <[email protected]>

* black

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* fix comment

Signed-off-by: harupy <[email protected]>

* nit

Signed-off-by: harupy <[email protected]>

* Fix syntax

Signed-off-by: harupy <[email protected]>

* lint

Signed-off-by: harupy <[email protected]>
harupy authored and dbczumar committed Aug 27, 2020
1 parent 95cf5e4 commit c0c7414
Showing 3 changed files with 9 additions and 511 deletions.
238 changes: 9 additions & 229 deletions mlflow/sklearn/__init__.py
@@ -539,15 +539,9 @@ def autolog():
- A fitted estimator (logged by :py:func:`mlflow.sklearn.log_model()`).
**How does autologging work for meta estimators?**
When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit()``, it internally calls
``fit()`` on its child estimators. Autologging does NOT perform logging on these constituent
``fit()`` calls.
**Parameter search**
In addition to recording the information discussed above, autologging for parameter
search meta estimators (`GridSearchCV`_ and `RandomizedSearchCV`_) records child runs
with metrics for each set of explored parameters, as well as artifacts and parameters
for the best model (if available).
When a meta estimator (e.g. `Pipeline`_, `GridSearchCV`_) calls ``fit``, it internally calls
``fit`` on its child estimators. Autologging does NOT perform logging on these constituent
``fit`` calls.
**Supported estimators**
All estimators obtained by `sklearn.utils.all_estimators`_ (including meta estimators).
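
For context, here is a minimal usage sketch of the behavior this docstring describes; the toy data and choice of estimator are illustrative, not part of the change:

```python
import mlflow.sklearn
from sklearn.linear_model import LinearRegression

mlflow.sklearn.autolog()  # patches fit/fit_transform/fit_predict on all estimators

X = [[0.0], [1.0], [2.0]]
y = [0.0, 1.0, 2.0]

# The patched fit() starts a run if none is active, logs get_params() as
# params and the estimator class as tags, records a training score metric,
# logs the fitted model as an artifact, then ends the run it started.
LinearRegression().fit(X, y)
```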
@@ -561,9 +555,6 @@ def autolog():
.. _GridSearchCV:
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.GridSearchCV.html
.. _RandomizedSearchCV:
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html
**Example**
.. code-block:: python
@@ -663,55 +654,7 @@ def fit_mlflow(self, func_name, *args, **kwargs):

raise e

_log_posttraining_metadata(self, *args, **kwargs)

if should_start_run:
try_mlflow_log(mlflow.end_run)

return fit_output
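
The `try_mlflow_log` helper used throughout this diff comes from MLflow's autologging utilities. A minimal sketch of the pattern, assuming it simply degrades logging failures to a warning so a tracking-store hiccup never breaks the user's training code:

```python
import logging

_logger = logging.getLogger(__name__)

def try_mlflow_log(fn, *args, **kwargs):
    """Invoke an MLflow logging call; warn instead of raising on failure."""
    try:
        return fn(*args, **kwargs)
    except Exception as e:  # deliberately broad: logging must never break fit()
        _logger.warning("Failed to log to MLflow: %s", e)
```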

def _log_pretraining_metadata(estimator, *args, **kwargs):
"""
Records metadata (e.g., params and tags) for a scikit-learn estimator prior to training.
This is intended to be invoked within a patched scikit-learn training routine
(e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
MLflow run that can be referenced via the fluent Tracking API.
:param estimator: The scikit-learn estimator for which to log metadata.
:param args: The arguments passed to the scikit-learn training routine (e.g.,
`fit()`, `fit_transform()`, ...).
:param kwargs: The keyword arguments passed to the scikit-learn training routine.
"""
# Deep parameter logging includes parameters from children of a given
# estimator. For some meta estimators (e.g., pipelines), recording
# these parameters is desirable. For parameter search estimators,
# however, child estimators act as seeds for the parameter search
# process; accordingly, we avoid logging initial, untuned parameters
# for these seed estimators.
should_log_params_deeply = not _is_parameter_search_estimator(estimator)
# Chunk model parameters to avoid hitting the log_batch API limit
for chunk in _chunk_dict(
estimator.get_params(deep=should_log_params_deeply),
chunk_size=MAX_PARAMS_TAGS_PER_BATCH,
):
truncated = _truncate_dict(chunk, MAX_ENTITY_KEY_LENGTH, MAX_PARAM_VAL_LENGTH)
try_mlflow_log(mlflow.log_params, truncated)

try_mlflow_log(mlflow.set_tags, _get_estimator_info_tags(estimator))
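
The chunk-and-truncate step above keeps each batched logging call under the Tracking API limits. A self-contained sketch of that pattern follows; the limit values and helper bodies are assumptions modeled on `_chunk_dict`/`_truncate_dict`, not MLflow's actual implementations:

```python
from itertools import islice

# Assumed limits mirroring MAX_PARAMS_TAGS_PER_BATCH, MAX_ENTITY_KEY_LENGTH,
# and MAX_PARAM_VAL_LENGTH in the diff; the real values live in MLflow.
BATCH_SIZE = 100
MAX_KEY_LEN = 250
MAX_VAL_LEN = 250

def chunk_dict(d, chunk_size):
    """Yield sub-dicts of at most chunk_size items each."""
    items = iter(d.items())
    while True:
        chunk = dict(islice(items, chunk_size))
        if not chunk:
            break
        yield chunk

def truncate_dict(d, max_key_len, max_val_len):
    """Truncate stringified keys and values to the given maximum lengths."""
    return {str(k)[:max_key_len]: str(v)[:max_val_len] for k, v in d.items()}
```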

def _log_posttraining_metadata(estimator, *args, **kwargs):
"""
Records metadata for a scikit-learn estimator after training has completed.
This is intended to be invoked within a patched scikit-learn training routine
(e.g., `fit()`, `fit_transform()`, ...) and assumes the existence of an active
MLflow run that can be referenced via the fluent Tracking API.
:param estimator: The scikit-learn estimator for which to log metadata.
:param args: The arguments passed to the scikit-learn training routine (e.g.,
`fit()`, `fit_transform()`, ...).
:param kwargs: The keyword arguments passed to the scikit-learn training routine.
"""
if hasattr(estimator, "score"):
if hasattr(self, "score"):
try:
score_args = _get_args_for_score(self.score, self.fit, args, kwargs)
training_score = self.score(*score_args)
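
`_get_args_for_score` rebuilds the `score()` argument list from the arguments passed to `fit()`, forwarding `sample_weight` only when both signatures accept it (see the commit "pass sample_weight if both fit and score have it"). A simplified, hypothetical sketch of that logic; the real helper handles positional arguments more carefully:

```python
import inspect

def get_args_for_score(score_func, fit_func, fit_args, fit_kwargs):
    """Build score() args from fit() args: (X, y), plus sample_weight when
    both fit() and score() accept it and it was passed to fit()."""
    X = fit_kwargs.get("X", fit_args[0] if fit_args else None)
    y = fit_kwargs.get("y", fit_args[1] if len(fit_args) > 1 else None)
    score_args = [X, y]

    both_take_weight = all(
        "sample_weight" in inspect.signature(f).parameters
        for f in (fit_func, score_func)
    )
    if both_take_weight and "sample_weight" in fit_kwargs:
        score_args.append(fit_kwargs["sample_weight"])
    return score_args
```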
@@ -727,45 +670,10 @@ def _log_posttraining_metadata(estimator, *args, **kwargs):

try_mlflow_log(log_model, self, artifact_path="model")

if _is_parameter_search_estimator(estimator):
if hasattr(estimator, "best_estimator_"):
try_mlflow_log(log_model, estimator.best_estimator_, artifact_path="best_estimator")

if hasattr(estimator, "best_params_"):
best_params = {
f"best_{param_name}": param_value
for param_name, param_value in estimator.best_params_.items()
}
try_mlflow_log(mlflow.log_params, best_params)

if hasattr(estimator, "cv_results_"):
try:
# Fetch environment-specific tags (e.g., user and source) to ensure that lineage
# information is consistent with the parent run
environment_tags = context_registry.resolve_tags()
_create_child_runs_for_parameter_search(
cv_estimator=estimator,
parent_run=mlflow.active_run(),
child_tags=environment_tags,
)
except Exception as e:
msg = (
"Encountered exception during creation of child runs for parameter search."
" Child runs may be missing. Exception: {}".format(str(e))
)
_logger.warning(msg)

try:
cv_results_df = pd.DataFrame.from_dict(estimator.cv_results_)
_log_parameter_search_results_as_artifact(
cv_results_df, mlflow.active_run().info.run_id
)
except Exception as e:
msg = (
"Failed to log parameter search results as an artifact."
" Exception: {}".format(str(e))
)
_logger.warning(msg)
if should_start_run:
try_mlflow_log(mlflow.end_run)

return fit_output
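
For parameter search estimators, the tuned values are re-logged with a `best_` prefix so they cannot collide with the pre-training parameters. A condensed sketch of that step, assuming a fitted `GridSearchCV`:

```python
import mlflow
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

search = GridSearchCV(Ridge(), {"alpha": [0.1, 1.0]}, cv=2)
search.fit([[0.0], [1.0], [2.0], [3.0]], [0.0, 1.0, 2.0, 3.0])

# Prefix each tuned parameter so it cannot overwrite the parameters already
# logged from get_params() before training, e.g. "alpha" -> "best_alpha".
best_params = {f"best_{name}": value for name, value in search.best_params_.items()}
mlflow.log_params(best_params)
```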

def patched_fit(self, func_name, *args, **kwargs):
"""
@@ -785,21 +693,8 @@ def f(self, *args, **kwargs):

return f

from sklearn.model_selection import GridSearchCV, RandomizedSearchCV

patch_settings = gorilla.Settings(allow_hit=True, store_hit=True)
try:
from sklearn.utils import all_estimators
except ImportError:
all_estimators = _all_estimators

_, estimators_to_patch = zip(*all_estimators())
# Ensure that relevant meta estimators (e.g. GridSearchCV, Pipeline) are selected
# for patching if they are not already included in the output of `all_estimators()`
estimators_to_patch = set(estimators_to_patch).union(
set(_get_meta_estimators_for_autologging())
)
for class_def in estimators_to_patch:
for _, class_def in _all_estimators():
for func_name in ["fit", "fit_transform", "fit_predict"]:
if hasattr(class_def, func_name):
original = getattr(class_def, func_name)
@@ -828,118 +723,3 @@ def f(self, *args, **kwargs):
patch_func = functools.wraps(original)(patch_func)
patch = gorilla.Patch(class_def, func_name, patch_func, settings=patch_settings)
gorilla.apply(patch)
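
The patching loop above relies on the `gorilla` monkey-patching library. A minimal sketch of the same mechanism, applied to a toy class rather than every scikit-learn estimator:

```python
import gorilla

class Greeter:
    def greet(self):
        return "hello"

def patched_greet(self):
    # Recover and wrap the pre-patch method, as patched_fit does above.
    original = gorilla.get_original_attribute(self, "greet")
    return original().upper()

settings = gorilla.Settings(allow_hit=True, store_hit=True)
patch = gorilla.Patch(Greeter, "greet", patched_greet, settings=settings)
gorilla.apply(patch)

print(Greeter().greet())  # -> "HELLO"
```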

def fit_predict_cv(self, *args, **kwargs):
return patched_fit_cv(self, 'fit_predict', *args, **kwargs)

def fit_transform_cv(self, *args, **kwargs):
return patched_fit_cv(self, 'fit_transform', *args, **kwargs)

def fit_cv(self, *args, **kwargs):
return patched_fit_cv(self, 'fit', *args, **kwargs)

def patched_fit_cv(self, fn_name, *args, **kwargs):
"""
To be applied to a sklearn model class that defines a `fit`
method and inherits from `BaseEstimator` (thereby defining
the `get_params()` method)
"""
with _SklearnTrainingSession(allow_children=False, clazz=self.__class__) as t:
if t.should_log():
return fit_mlflow_cv(self, fn_name, *args, **kwargs)
else:
original_fit = gorilla.get_original_attribute(self, fn_name)
return original_fit(*args, **kwargs)

def fit_mlflow_cv(self, fn_name, *args, **kwargs):
try_mlflow_log(mlflow.start_run, nested=True)
# Perform shallow parameter logging for hyperparameter search APIs (e.g., GridSearchCV
# and RandomizedSearchCV) to avoid logging superfluous parameters from the seed
# `estimator` constructor argument; we will log the set of optimal estimator
# parameters, if available, once training completes
try_mlflow_log(mlflow.log_params, self.get_params(deep=False))
try_mlflow_log(mlflow.set_tag, "estimator_name", self.__class__.__name__)
try_mlflow_log(mlflow.set_tag, "estimator_class", self.__class__)

original_fit = gorilla.get_original_attribute(self, fn_name)
fit_output = original_fit(*args, **kwargs)

if hasattr(self, 'score'):
try:
training_score = self.score(args[0], args[1])
try_mlflow_log(mlflow.log_metric, "training_score", training_score)
except Exception as e:
print("Failed to collect scoring metrics!")
print(e)

try_mlflow_log(log_model, self, artifact_path='model')
if hasattr(self, 'best_estimator_'):
try_mlflow_log(log_model, self.best_estimator_, artifact_path="best_estimator")

if hasattr(self, 'best_params_'):
best_params = {
f"best_{param_name}": param_value
for param_name, param_value in self.best_params_.items()
}
try_mlflow_log(mlflow.log_params, best_params)

_create_child_cv_runs(cv_estimator=self)

try_mlflow_log(mlflow.end_run)

def _create_child_cv_runs(cv_estimator):
from mlflow.tracking.client import MlflowClient
from mlflow.entities import Metric, RunTag, Param
import pandas as pd
import time
from numbers import Number

client = MlflowClient()
metrics_timestamp = int(time.time() * 1000)

cv_results_df = pd.DataFrame.from_dict(cv_estimator.cv_results_)
for _, result_row in cv_results_df.iterrows():
with mlflow.start_run(nested=True):
params = [
Param(str(key), str(value)) for key, value in result_row.get("params", {}).items()
]
metrics = {
Metric(
key=key,
value=value,
timestamp=metrics_timestamp,
step=0,
)
for key, value in result_row.iteritems()
# Parameters values are recorded twice in the set of search `cv_results`:
# once within a `params` column with dictionary values and once within
# a separate dataframe column that is created for each parameter. To prevent
# duplication of parameters, we log the consolidated values from the parameter
# dictionary column and filter out the other parameter-specific columns with
# names of the form `param_{param_name}`.
if not key.startswith("param") and isinstance(value, Number)
}
tags = [
RunTag(key, value) for key, value in {
"estimator_name": str(cv_estimator.estimator.__class__.__name__),
"estimator_class": str(cv_estimator.estimator.__class__),
}.items()
]

client.log_batch(
run_id=mlflow.active_run().info.run_id,
params=params,
metrics=metrics,
tags=tags,
)

for class_def in [GridSearchCV, RandomizedSearchCV]:
if hasattr(class_def, 'fit'):
patch = gorilla.Patch(class_def, 'fit', fit_cv, settings=patch_settings)
gorilla.apply(patch)
if hasattr(class_def, 'fit_transform'):
patch = gorilla.Patch(class_def, 'fit_transform_cv', fit_transform_cv, settings=patch_settings)
gorilla.apply(patch)
if hasattr(class_def, 'fit_predict'):
patch = gorilla.Patch(class_def, 'fit_predict_cv', fit_predict_cv, settings=patch_settings)
gorilla.apply(patch)
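
The `_create_child_cv_runs` function above logs one nested run per `cv_results_` row. A condensed, runnable restatement of that logic (a list replaces the original set comprehension, and the deprecated `iteritems()` is swapped for `items()`):

```python
import time
from numbers import Number

import mlflow
import pandas as pd
from mlflow.entities import Metric, Param
from mlflow.tracking.client import MlflowClient

def create_child_runs(search):
    """Log one nested MLflow run per row of a fitted search's cv_results_."""
    client = MlflowClient()
    timestamp = int(time.time() * 1000)
    for _, row in pd.DataFrame.from_dict(search.cv_results_).iterrows():
        with mlflow.start_run(nested=True):
            params = [Param(str(k), str(v)) for k, v in row.get("params", {}).items()]
            # Skip the per-parameter "param_*" columns: their values already
            # appear in the consolidated "params" dictionary column.
            metrics = [
                Metric(key=k, value=v, timestamp=timestamp, step=0)
                for k, v in row.items()
                if not k.startswith("param") and isinstance(v, Number)
            ]
            client.log_batch(
                run_id=mlflow.active_run().info.run_id,
                metrics=metrics,
                params=params,
            )
```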
