diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 77dd8256..cdc85e91 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -4,6 +4,7 @@ from typing import Any, Iterable, List, Literal, Optional, Tuple, TypeVar, Union + try: from typing import Self # type: ignore[attr-defined] except ImportError: @@ -17,7 +18,7 @@ from sklearn.base import clone from sklearn.pipeline import Pipeline as _Pipeline from sklearn.pipeline import _final_estimator_has, _fit_transform_one -from sklearn.utils import Bunch, _print_elapsed_time +from sklearn.utils import Bunch from sklearn.utils.metadata_routing import ( _routing_enabled, # pylint: disable=protected-access ) @@ -32,6 +33,7 @@ PostPredictionTransformation, PostPredictionWrapper, ) +from molpipeline.utils.logging import print_elapsed_time from molpipeline.utils.molpipeline_types import ( AnyElement, AnyPredictor, @@ -240,7 +242,7 @@ def _fit( for step in self._iter(with_final=False, filter_passthrough=False): step_idx, name, transformer = step if transformer is None or transformer == "passthrough": - with _print_elapsed_time("Pipeline", self._log_message(step_idx)): + with print_elapsed_time("Pipeline", self._log_message(step_idx)): continue if hasattr(memory, "location") and memory.location is None: @@ -457,7 +459,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: """ routed_params = self._check_method_params(method="fit", props=fit_params) Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator != "passthrough": if is_empty(Xt): logger.warning( @@ -528,7 +530,7 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: routed_params = self._check_method_params(method="fit_transform", props=params) iter_input, 
iter_label = self._fit(X, y, routed_params) last_step = self._final_estimator - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if last_step == "passthrough": pass elif is_empty(iter_input): @@ -648,7 +650,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: ) # pylint: disable=invalid-name params_last_step = routed_params[self.steps[-1][0]] - with _print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator == "passthrough": y_pred = iter_input elif is_empty(iter_input): diff --git a/molpipeline/utils/logging.py b/molpipeline/utils/logging.py new file mode 100644 index 00000000..fbf7fa77 --- /dev/null +++ b/molpipeline/utils/logging.py @@ -0,0 +1,81 @@ +"""Logging helper functions.""" + +from __future__ import annotations + +import timeit +from contextlib import contextmanager +from typing import Generator + +from loguru import logger + + +def _message_with_time(source: str, message: str, time: float) -> str: + """Create one line message for logging purposes. + + Adapted from sklearn's function to stay consistent with the logging style: + https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py + + Parameters + ---------- + source : str + String indicating the source or the reference of the message. + message : str + Short message. + time : float + Time in seconds. + + Returns + ------- + str + Message with elapsed time. 
+ """ + start_message = f"[{source}] " + + # adapted from joblib.logger.short_format_time without the Windows -.1s + # adjustment + if time > 60: + time_str = f"{(time / 60):4.1f}min" + else: + time_str = f" {time:5.1f}s" + + end_message = f" {message}, total={time_str}" + dots_len = 70 - len(start_message) - len(end_message) + return f"{start_message}{dots_len * '.'}{end_message}" + + +@contextmanager +def print_elapsed_time( + source: str, message: str | None = None, use_logger: bool = False +) -> Generator[None, None, None]: + """Print or log the elapsed time when the context is exited. + + Adapted from sklearn's function to stay consistent with the logging style: + https://github.com/scikit-learn/scikit-learn/blob/e16a6ddebd527e886fc22105710ee20ce255f9f0/sklearn/utils/_user_interface.py + + Parameters + ---------- + source : str + String indicating the source or the reference of the message. + message : str, default=None + Short message. If None, nothing will be printed. + use_logger : bool, default=False + If True, log the message with the loguru logger instead of printing it. + + Returns + ------- + context_manager + Prints or logs elapsed time upon exit if message is not None. 
+ """ + if message is None: + yield + else: + start = timeit.default_timer() + yield + message_to_print = _message_with_time( + source, message, timeit.default_timer() - start + ) + + if use_logger: + logger.info(message_to_print) + else: + print(message_to_print) diff --git a/tests/test_utils/test_logging.py b/tests/test_utils/test_logging.py new file mode 100644 index 00000000..f737119c --- /dev/null +++ b/tests/test_utils/test_logging.py @@ -0,0 +1,34 @@ +"""Test logging utils.""" + +import io +import unittest +from contextlib import redirect_stdout + +from molpipeline.utils.logging import print_elapsed_time + + +class LoggingUtilsTest(unittest.TestCase): + """Unittest for the logging utility functions.""" + + def test__print_elapsed_time(self) -> None: + """Test that message logging with timings works as expected.""" + + # when message is None nothing should be printed + stream1 = io.StringIO() + with redirect_stdout(stream1): + with print_elapsed_time("source", message=None, use_logger=False): + pass + output1 = stream1.getvalue() + self.assertEqual(output1, "") + + # message should be printed in the expected sklearn format + stream2 = io.StringIO() + with redirect_stdout(stream2): + with print_elapsed_time("source", message="my message", use_logger=False): + pass + output2 = stream2.getvalue() + self.assertTrue( + output2.startswith( + "[source] ................................... my message, total=" + ) + )