diff --git a/autosklearn/automl.py b/autosklearn/automl.py
index 278cd5c146..f76a03adec 100644
--- a/autosklearn/automl.py
+++ b/autosklearn/automl.py
@@ -21,7 +21,6 @@
 import os
 import platform
 import sys
-import tempfile
 import time
 import types
 import uuid
@@ -37,7 +36,7 @@
 import sklearn.utils
 from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
 from ConfigSpace.read_and_write import json as cs_json
-from dask.distributed import Client, LocalCluster
+from dask.distributed import Client
 from scipy.sparse import spmatrix
 from sklearn.base import BaseEstimator
 from sklearn.dummy import DummyClassifier, DummyRegressor
@@ -105,6 +104,7 @@
 from autosklearn.pipeline.components.regression import RegressorChoice
 from autosklearn.smbo import AutoMLSMBO
 from autosklearn.util import RE_PATTERN, pipeline
+from autosklearn.util.dask import Dask, LocalDask, UserDask
 from autosklearn.util.data import (
     DatasetCompressionSpec,
     default_dataset_compression_arg,
@@ -120,7 +120,6 @@
     warnings_to,
 )
 from autosklearn.util.parallel import preload_modules
-from autosklearn.util.single_thread_client import SingleThreadedClient
 from autosklearn.util.smac_wrap import SMACCallback, SmacRunCallback
 from autosklearn.util.stopwatch import StopWatch
 
@@ -299,21 +298,22 @@ def __init__(
         self._initial_configurations_via_metalearning = (
             initial_configurations_via_metalearning
         )
+        self._n_jobs = n_jobs
         self._scoring_functions = scoring_functions or []
         self._resampling_strategy_arguments = resampling_strategy_arguments or {}
+        self._multiprocessing_context = "forkserver"
 
         # Single core, local runs should use fork to prevent the __main__ requirements
         # in examples. Nevertheless, multi-process runs have spawn as requirement to
         # reduce the possibility of a deadlock
-        if n_jobs == 1 and dask_client is None:
-            self._multiprocessing_context = "fork"
-            self._dask_client = SingleThreadedClient()
-            self._n_jobs = 1
+        self._dask: Dask
+        if dask_client is not None:
+            self._dask = UserDask(client=dask_client)
         else:
-            self._multiprocessing_context = "forkserver"
-            self._dask_client = dask_client
-            self._n_jobs = n_jobs
+            self._dask = LocalDask(n_jobs=n_jobs)
+            if n_jobs == 1:
+                self._multiprocessing_context = "fork"
 
         # Create the backend
         self._backend: Backend = create(
@@ -350,38 +350,6 @@
         self.num_run = 0
         self.fitted = False
 
-    def _create_dask_client(self) -> None:
-        self._is_dask_client_internally_created = True
-        self._dask_client = Client(
-            LocalCluster(
-                n_workers=self._n_jobs,
-                processes=False,
-                threads_per_worker=1,
-                # We use the temporal directory to save the
-                # dask workers, because deleting workers takes
-                # more time than deleting backend directories
-                # This prevent an error saying that the worker
-                # file was deleted, so the client could not close
-                # the worker properly
-                local_directory=tempfile.gettempdir(),
-                # Memory is handled by the pynisher, not by the dask worker/nanny
-                memory_limit=0,
-            ),
-            # Heartbeat every 10s
-            heartbeat_interval=10000,
-        )
-
-    def _close_dask_client(self, force: bool = False) -> None:
-        if getattr(self, "_dask_client", None) is not None and (
-            force or getattr(self, "_is_dask_client_internally_created", False)
-        ):
-            self._dask_client.shutdown()
-            self._dask_client.close()
-            del self._dask_client
-            self._dask_client = None
-            self._is_dask_client_internally_created = False
-            del self._is_dask_client_internally_created
-
     def _get_logger(self, name: str) -> PicklableClientLogger:
         logger_name = "AutoML(%d):%s" % (self._seed, name)
 
@@ -747,17 +715,6 @@ def fit(
                 "autosklearn.metrics.Scorer."
             )
 
-        # If no dask client was provided, we create one, so that we can
-        # start a ensemble process in parallel to smbo optimize
-        if self._dask_client is None and (
-            self._ensemble_class is not None
-            or self._n_jobs is not None
-            and self._n_jobs > 1
-        ):
-            self._create_dask_client()
-        else:
-            self._is_dask_client_internally_created = False
-
         self._dataset_name = dataset_name
         self._stopwatch.start(self._dataset_name)
 
@@ -902,70 +859,85 @@ def fit(
             )
 
         n_meta_configs = self._initial_configurations_via_metalearning
-        _proc_smac = AutoMLSMBO(
-            config_space=self.configuration_space,
-            dataset_name=self._dataset_name,
-            backend=self._backend,
-            total_walltime_limit=time_left,
-            func_eval_time_limit=per_run_time_limit,
-            memory_limit=self._memory_limit,
-            data_memory_limit=self._data_memory_limit,
-            stopwatch=self._stopwatch,
-            n_jobs=self._n_jobs,
-            dask_client=self._dask_client,
-            start_num_run=self.num_run,
-            num_metalearning_cfgs=n_meta_configs,
-            config_file=configspace_path,
-            seed=self._seed,
-            metadata_directory=self._metadata_directory,
-            metrics=self._metrics,
-            resampling_strategy=self._resampling_strategy,
-            resampling_strategy_args=self._resampling_strategy_arguments,
-            include=self._include,
-            exclude=self._exclude,
-            disable_file_output=self._disable_evaluator_output,
-            get_smac_object_callback=self._get_smac_object_callback,
-            smac_scenario_args=self._smac_scenario_args,
-            scoring_functions=self._scoring_functions,
-            port=self._logger_port,
-            pynisher_context=self._multiprocessing_context,
-            ensemble_callback=proc_ensemble,
-            trials_callback=self._get_trials_callback,
-        )
+        with self._dask as dask_client:
+            resamp_args = self._resampling_strategy_arguments
+            _proc_smac = AutoMLSMBO(
+                config_space=self.configuration_space,
+                dataset_name=self._dataset_name,
+                backend=self._backend,
+                total_walltime_limit=time_left,
+                func_eval_time_limit=per_run_time_limit,
+                memory_limit=self._memory_limit,
+                data_memory_limit=self._data_memory_limit,
+                stopwatch=self._stopwatch,
+                n_jobs=self._n_jobs,
+                dask_client=dask_client,
+                start_num_run=self.num_run,
+                num_metalearning_cfgs=n_meta_configs,
+                config_file=configspace_path,
+                seed=self._seed,
+                metadata_directory=self._metadata_directory,
+                metrics=self._metrics,
+                resampling_strategy=self._resampling_strategy,
+                resampling_strategy_args=resamp_args,
+                include=self._include,
+                exclude=self._exclude,
+                disable_file_output=self._disable_evaluator_output,
+                get_smac_object_callback=self._get_smac_object_callback,
+                smac_scenario_args=self._smac_scenario_args,
+                scoring_functions=self._scoring_functions,
+                port=self._logger_port,
+                pynisher_context=self._multiprocessing_context,
+                ensemble_callback=proc_ensemble,
+                trials_callback=self._get_trials_callback,
+            )
 
-        (
-            self.runhistory_,
-            self.trajectory_,
-            self._budget_type,
-        ) = _proc_smac.run_smbo()
-        trajectory_filename = os.path.join(
-            self._backend.get_smac_output_directory_for_run(self._seed),
-            "trajectory.json",
-        )
-        saveable_trajectory = [
-            list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
-            for entry in self.trajectory_
-        ]
-        with open(trajectory_filename, "w") as fh:
-            json.dump(saveable_trajectory, fh)
-
-        self._logger.info("Starting shutdown...")
-        # Wait until the ensemble process is finished to avoid shutting down
-        # while the ensemble builder tries to access the data
-        if proc_ensemble is not None:
-            self.ensemble_performance_history = list(proc_ensemble.history)
-
-            if len(proc_ensemble.futures) > 0:
-                # Now we wait for the future to return as it cannot be cancelled
-                # while it is running: https://stackoverflow.com/a/49203129
-                self._logger.info(
-                    "Ensemble script still running, waiting for it to finish."
-                )
-                result = proc_ensemble.futures.pop().result()
-                if result:
-                    ensemble_history, _ = result
-                    self.ensemble_performance_history.extend(ensemble_history)
-                self._logger.info("Ensemble script finished, continue shutdown.")
+            (
+                self.runhistory_,
+                self.trajectory_,
+                self._budget_type,
+            ) = _proc_smac.run_smbo()
+
+            trajectory_filename = os.path.join(
+                self._backend.get_smac_output_directory_for_run(self._seed),
+                "trajectory.json",
+            )
+            saveable_trajectory = [
+                list(entry[:2])
+                + [entry[2].get_dictionary()]
+                + list(entry[3:])
+                for entry in self.trajectory_
+            ]
+            with open(trajectory_filename, "w") as fh:
+                json.dump(saveable_trajectory, fh)
+
+            self._logger.info("Starting shutdown...")
+            # Wait until the ensemble process is finished to avoid shutting
+            # down while the ensemble builder tries to access the data
+            if proc_ensemble is not None:
+                self.ensemble_performance_history = list(
+                    proc_ensemble.history
+                )
+
+                if len(proc_ensemble.futures) > 0:
+                    # Now we wait for the future to return as it cannot be
+                    # cancelled while it is running
+                    # * https://stackoverflow.com/a/49203129
+                    self._logger.info(
+                        "Ensemble script still running,"
+                        " waiting for it to finish."
+                    )
+                    result = proc_ensemble.futures.pop().result()
+
+                    if result:
+                        ensemble_history, _ = result
+                        self.ensemble_performance_history.extend(
+                            ensemble_history
+                        )
+
+                    self._logger.info(
+                        "Ensemble script finished, continue shutdown."
+                    )
 
         # save the ensemble performance history file
         if len(self.ensemble_performance_history) > 0:
@@ -1054,7 +1026,7 @@ def _log_fit_setup(self) -> None:
         self._logger.debug(
             " multiprocessing_context: %s", str(self._multiprocessing_context)
         )
-        self._logger.debug(" dask_client: %s", str(self._dask_client))
+        self._logger.debug(" dask_client: %s", str(self._dask))
         self._logger.debug(" precision: %s", str(self.precision))
         self._logger.debug(
             " disable_evaluator_output: %s", str(self._disable_evaluator_output)
         )
@@ -1090,7 +1062,6 @@ def __sklearn_is_fitted__(self) -> bool:
 
     def _fit_cleanup(self) -> None:
         self._logger.info("Closing the dask infrastructure")
-        self._close_dask_client()
         self._logger.info("Finished closing the dask infrastructure")
 
         # Clean up the logger
@@ -1555,12 +1526,6 @@ def fit_ensemble(
         # Make sure that input is valid
         y = self.InputValidator.target_validator.transform(y)
 
-        # Create a client if needed
-        if self._dask_client is None:
-            self._create_dask_client()
-        else:
-            self._is_dask_client_internally_created = False
-
         metrics = metrics if metrics is not None else self._metrics
         if not isinstance(metrics, Sequence):
             metrics = [metrics]
@@ -1568,35 +1533,41 @@
         # Use the current thread to start the ensemble builder process
         # The function ensemble_builder_process will internally create a ensemble
         # builder in the provide dask client
-        manager = EnsembleBuilderManager(
-            start_time=time.time(),
-            time_left_for_ensembles=self._time_for_task,
-            backend=copy.deepcopy(self._backend),
-            dataset_name=dataset_name if dataset_name else self._dataset_name,
-            task=task if task else self._task,
-            metrics=metrics if metrics is not None else self._metrics,
-            ensemble_class=(
-                ensemble_class if ensemble_class is not None else self._ensemble_class
-            ),
-            ensemble_kwargs=(
-                ensemble_kwargs
-                if ensemble_kwargs is not None
-                else self._ensemble_kwargs
-            ),
-            ensemble_nbest=ensemble_nbest if ensemble_nbest else self._ensemble_nbest,
-            max_models_on_disc=self._max_models_on_disc,
-            seed=self._seed,
-            precision=precision if precision else self.precision,
-            max_iterations=1,
-            read_at_most=None,
-            memory_limit=self._memory_limit,
-            random_state=self._seed,
-            logger_port=self._logger_port,
-            pynisher_context=self._multiprocessing_context,
-        )
-        manager.build_ensemble(self._dask_client)
-        future = manager.futures.pop()
-        result = future.result()
+        with self._dask as dask_client:
+            manager = EnsembleBuilderManager(
+                start_time=time.time(),
+                time_left_for_ensembles=self._time_for_task,
+                backend=copy.deepcopy(self._backend),
+                dataset_name=dataset_name if dataset_name else self._dataset_name,
+                task=task if task else self._task,
+                metrics=metrics if metrics is not None else self._metrics,
+                ensemble_class=(
+                    ensemble_class
+                    if ensemble_class is not None
+                    else self._ensemble_class
+                ),
+                ensemble_kwargs=(
+                    ensemble_kwargs
+                    if ensemble_kwargs is not None
+                    else self._ensemble_kwargs
+                ),
+                ensemble_nbest=ensemble_nbest
+                if ensemble_nbest
+                else self._ensemble_nbest,
+                max_models_on_disc=self._max_models_on_disc,
+                seed=self._seed,
+                precision=precision if precision else self.precision,
+                max_iterations=1,
+                read_at_most=None,
+                memory_limit=self._memory_limit,
+                random_state=self._seed,
+                logger_port=self._logger_port,
+                pynisher_context=self._multiprocessing_context,
+            )
+            manager.build_ensemble(dask_client)
+            future = manager.futures.pop()
+            result = future.result()
+
         if result is None:
             raise ValueError(
                 "Error building the ensemble - please check the log file and command "
@@ -1606,7 +1577,6 @@ def fit_ensemble(
         self._ensemble_class = ensemble_class
 
         self._load_models()
-        self._close_dask_client()
         return self
 
     def _load_models(self):
@@ -2295,7 +2265,7 @@ def _create_search_space(
 
     def __getstate__(self) -> dict[str, Any]:
         # Cannot serialize a client!
-        self._dask_client = None
+        self._dask = None
        self.logging_server = None
         self.stop_logging_server = None
         return self.__dict__
@@ -2304,8 +2274,6 @@ def __del__(self) -> None:
         # Clean up the logger
         self._clean_logger()
 
-        self._close_dask_client()
-
 
 class AutoMLClassifier(AutoML):
diff --git a/autosklearn/util/dask.py b/autosklearn/util/dask.py
new file mode 100644
index 0000000000..624fecfae9
--- /dev/null
+++ b/autosklearn/util/dask.py
@@ -0,0 +1,142 @@
+"""Provides the two simplified dask use cases that we consider
+
+1. A UserDask wraps a dask client supplied by the user, in which case
+we never close it down and leave its lifetime up to the user to control.
+2. A LocalDask is the one we use when no user client is supplied. In this case
+we make sure to spin up and close down clients as needed.
+
+Both of these can be uniformly accessed as a context manager.
+
+.. code:: python
+
+    # Locally controlled dask client
+    local_dask = LocalDask(n_jobs=2)
+    with local_dask as client:
+        # Do stuff with client
+        ...
+
+    # `client` is shut down properly
+
+    # ----------------
+
+    # User controlled dask client
+    user_dask = UserDask(user_client)
+
+    with user_dask as client:
+        # Do stuff with (client == user_client)
+        ...
+
+    # `user_client` is still open and up to the user to close
+"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+import tempfile
+
+from dask.distributed import Client, LocalCluster
+
+from autosklearn.util.single_thread_client import SingleThreadedClient
+
+
+class Dask(ABC):
+    @abstractmethod
+    def client(self) -> Client:
+        """Should return a dask client"""
+        ...
+
+    @abstractmethod
+    def close(self) -> None:
+        """Should close any resources needed for the dask client"""
+        ...
+
+    def __enter__(self) -> Client:
+        return self.client()
+
+    def __exit__(self, *args: Any, **kwargs: Any) -> None:
+        self.close()
+
+    @abstractmethod
+    def __repr__(self) -> str:
+        ...
+
+
+class UserDask(Dask):
+    """A dask instance created by a user"""
+
+    def __init__(self, client: Client):
+        """
+        Parameters
+        ----------
+        client : Client
+            The client the user passed in
+        """
+        self._client = client
+
+    def client(self) -> Client:
+        """The dask client"""
+        return self._client
+
+    def close(self) -> None:
+        """Close the dask client"""
+        # We do nothing, it's user provided
+        pass
+
+    def __repr__(self) -> str:
+        return "UserDask(...)"
+
+
+class LocalDask(Dask):
+    def __init__(self, n_jobs: int | None = None) -> None:
+        self.n_jobs = n_jobs
+        self._client: Client | None = None
+        self._cluster: LocalCluster | None = None
+
+    def client(self) -> Client:
+        """Creates a usable dask client or returns an existing one
+
+        If there is no current client, e.g. because it has been closed,
+        create a new one:
+        * If ``n_jobs == 1``, create a ``SingleThreadedClient``
+        * Else create a ``Client`` with a ``LocalCluster``
+        """
+        if self._client is not None:
+            return self._client
+
+        if self.n_jobs == 1:
+            cluster = None
+            client = SingleThreadedClient()
+        else:
+            cluster = LocalCluster(
+                n_workers=self.n_jobs,
+                processes=False,
+                threads_per_worker=1,
+                # We use tmpdir to save the workers as deleting workers takes
+                # more time than deleting backend directories.
+                # This prevents an error saying that the worker file was deleted,
+                # so the client could not close the worker properly
+                local_directory=tempfile.gettempdir(),
+                # Memory is handled by the pynisher, not by the dask worker/nanny
+                memory_limit=0,
+            )
+            client = Client(cluster, heartbeat_interval=10000)  # 10s
+
+        self._client = client
+        self._cluster = cluster
+        return self._client
+
+    def close(self) -> None:
+        """Closes any open dask client"""
+        if self._client is None:
+            return
+
+        self._client.close()
+        if self._cluster is not None:
+            self._cluster.close()
+
+        self._client = None
+        self._cluster = None
+
+    def __repr__(self) -> str:
+        return f"LocalDask(n_jobs = {self.n_jobs})"
diff --git a/test/test_automl/test_construction.py b/test/test_automl/test_construction.py
index 5b68d35118..be6fe0e39b 100644
--- a/test/test_automl/test_construction.py
+++ b/test/test_automl/test_construction.py
@@ -2,6 +2,7 @@
 from typing import Any, Dict, Optional, Union
 
 from autosklearn.automl import AutoML
+from autosklearn.util.dask import LocalDask
 from autosklearn.util.data import default_dataset_compression_arg
 from autosklearn.util.single_thread_client import SingleThreadedClient
 
@@ -87,4 +88,7 @@ def test_single_job_and_no_dask_client_sets_correct_multiprocessing_context() ->
     assert automl._multiprocessing_context == "fork"
     assert automl._n_jobs == 1
-    assert isinstance(automl._dask_client, SingleThreadedClient)
+    assert isinstance(automl._dask, LocalDask)
+
+    with automl._dask as client:
+        assert isinstance(client, SingleThreadedClient)
diff --git a/test/test_estimators/test_estimators.py b/test/test_estimators/test_estimators.py
index d0d3f28bdb..e1e33d684a 100644
--- a/test/test_estimators/test_estimators.py
+++ b/test/test_estimators/test_estimators.py
@@ -140,8 +140,6 @@ def __call__(self, *args, **kwargs):
     assert count_succeses(automl.cv_results_) > 0
     assert includes_train_scores(automl.performance_over_time_.columns) is True
     assert performance_over_time_is_plausible(automl.performance_over_time_) is True
-    # For travis-ci it is important that the client no longer exists
-    assert automl.automl_._dask_client is None
 
 
 def test_feat_type_wrong_arguments():
diff --git a/test/test_util/test_dask.py b/test/test_util/test_dask.py
new file mode 100644
index 0000000000..1dbc290500
--- /dev/null
+++ b/test/test_util/test_dask.py
@@ -0,0 +1,75 @@
+from pathlib import Path
+
+from dask.distributed import Client, LocalCluster
+
+from autosklearn.util.dask import LocalDask, UserDask
+
+import pytest
+
+
+@pytest.mark.parametrize("n_jobs", [1, 2])
+def test_user_dask(tmp_path: Path, n_jobs: int) -> None:
+    """
+    Expects
+    -------
+    * A UserDask should not close the client after exiting its context
+    """
+    cluster = LocalCluster(
+        n_workers=n_jobs,
+        processes=False,
+        threads_per_worker=1,
+        local_directory=tmp_path,
+    )
+    client = Client(cluster, heartbeat_interval=10000)
+
+    # Active at creation
+    dask = UserDask(client)
+
+    client_1 = None
+    with dask as user_client:
+        client_1 = user_client
+        assert user_client.status == "running"
+
+    client_2 = None
+    with dask as user_client:
+        assert user_client.status == "running"
+        client_2 = user_client
+
+    # Make sure they are the same client
+    assert id(client_1) == id(client_2)
+
+    # Remains running after context
+    assert client_1.status == "running"
+
+    cluster.close()
+    client.close()
+
+    assert client.status == "closed"
+
+
+def test_local_dask_creates_new_clients(tmp_path: Path) -> None:
+    """
+    Expects
+    -------
+    * A LocalDask should create new dask clusters at each context usage
+    """
+    # We need n_jobs == 2 to get an actual dask client and not a SingleThreadedClient
+    local_dask = LocalDask(n_jobs=2)
+
+    client_1 = None
+    with local_dask as client:
+        client_1 = client
+        assert client_1.status == "running"
+
+    assert client_1.status == "closed"
+
+    client_2 = None
+    with local_dask as client:
+        client_2 = client
+        assert client_2.status == "running"
+
+    # Make sure they were different clients
+    assert id(client_1) != id(client_2)
+
+    assert client_2.status == "closed"
+    assert client_1.status == "closed"
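
Usage sketch (illustrative only, not part of the patch): the snippet below shows how the new abstraction is meant to be driven, mirroring what `AutoML.__init__` and `fit()` do in the diff above: wrap a user-supplied client in `UserDask`, otherwise fall back to `LocalDask`, and always obtain the client through the context manager. The `make_dask` helper and the toy submitted task are hypothetical; only `Dask`, `UserDask`, `LocalDask` and the `n_jobs`/`dask_client` parameters come from the patch.

.. code:: python

    from __future__ import annotations

    from dask.distributed import Client

    from autosklearn.util.dask import Dask, LocalDask, UserDask


    def make_dask(n_jobs: int, dask_client: Client | None = None) -> Dask:
        # Hypothetical helper mirroring AutoML.__init__: wrap a user client if
        # one was given, otherwise let LocalDask manage the lifetime (it yields
        # a SingleThreadedClient when n_jobs == 1, a LocalCluster-backed Client
        # otherwise).
        if dask_client is not None:
            return UserDask(client=dask_client)
        return LocalDask(n_jobs=n_jobs)


    dask = make_dask(n_jobs=2)
    with dask as client:
        # LocalDask spins a cluster up lazily here and closes it again when the
        # context exits; UserDask would leave the supplied client untouched.
        future = client.submit(sum, [1, 2, 3])
        print(future.result())  # 6

This context-manager shape is what lets `fit()` and `fit_ensemble()` drop the old `_create_dask_client`/`_close_dask_client` bookkeeping removed above.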