Skip to content

Commit

Permalink
fix-1532-_ERROR_-asyncio.exceptions.CancelledError (#1540)
Browse files Browse the repository at this point in the history
* Create PR

* Abstract out dask client types

* Fix _ issue

* Extend scope of dask_client in automl.py

* Add docstring to dask module

* Indent result addition

* Add basic tests for Dask wrappers
  • Loading branch information
eddiebergman authored Jul 16, 2022
1 parent 12fe449 commit af9d469
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 160 deletions.
282 changes: 125 additions & 157 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import os
import platform
import sys
import tempfile
import time
import types
import uuid
Expand All @@ -37,7 +36,7 @@
import sklearn.utils
from ConfigSpace.configuration_space import Configuration, ConfigurationSpace
from ConfigSpace.read_and_write import json as cs_json
from dask.distributed import Client, LocalCluster
from dask.distributed import Client
from scipy.sparse import spmatrix
from sklearn.base import BaseEstimator
from sklearn.dummy import DummyClassifier, DummyRegressor
Expand Down Expand Up @@ -105,6 +104,7 @@
from autosklearn.pipeline.components.regression import RegressorChoice
from autosklearn.smbo import AutoMLSMBO
from autosklearn.util import RE_PATTERN, pipeline
from autosklearn.util.dask import Dask, LocalDask, UserDask
from autosklearn.util.data import (
DatasetCompressionSpec,
default_dataset_compression_arg,
Expand All @@ -120,7 +120,6 @@
warnings_to,
)
from autosklearn.util.parallel import preload_modules
from autosklearn.util.single_thread_client import SingleThreadedClient
from autosklearn.util.smac_wrap import SMACCallback, SmacRunCallback
from autosklearn.util.stopwatch import StopWatch

Expand Down Expand Up @@ -299,21 +298,22 @@ def __init__(
self._initial_configurations_via_metalearning = (
initial_configurations_via_metalearning
)
self._n_jobs = n_jobs

self._scoring_functions = scoring_functions or []
self._resampling_strategy_arguments = resampling_strategy_arguments or {}
self._multiprocessing_context = "forkserver"

# Single core, local runs should use fork to prevent the __main__ requirements
# in examples. Nevertheless, multi-process runs have spawn as requirement to
# reduce the possibility of a deadlock
if n_jobs == 1 and dask_client is None:
self._multiprocessing_context = "fork"
self._dask_client = SingleThreadedClient()
self._n_jobs = 1
self._dask: Dask
if dask_client is not None:
self._dask = UserDask(client=dask_client)
else:
self._multiprocessing_context = "forkserver"
self._dask_client = dask_client
self._n_jobs = n_jobs
self._dask = LocalDask(n_jobs=n_jobs)
if n_jobs == 1:
self._multiprocessing_context = "fork"

# Create the backend
self._backend: Backend = create(
Expand Down Expand Up @@ -350,38 +350,6 @@ def __init__(
self.num_run = 0
self.fitted = False

def _create_dask_client(self) -> None:
    """Create an in-process local Dask cluster and attach a client to it.

    Sets ``self._dask_client`` to a new :class:`dask.distributed.Client`
    backed by a :class:`~dask.distributed.LocalCluster`, and flags it as
    internally created so ``_close_dask_client`` knows this instance owns
    (and must tear down) the client, as opposed to one supplied by the
    user via the ``dask_client`` constructor argument.
    """
    # Mark ownership: cleanup only shuts down clients we created ourselves.
    self._is_dask_client_internally_created = True
    self._dask_client = Client(
        LocalCluster(
            # One worker per requested job; in-process threads (no
            # separate worker processes), each running single-threaded.
            n_workers=self._n_jobs,
            processes=False,
            threads_per_worker=1,
            # We use the temporal directory to save the
            # dask workers, because deleting workers takes
            # more time than deleting backend directories
            # This prevents an error saying that the worker
            # file was deleted, so the client could not close
            # the worker properly
            local_directory=tempfile.gettempdir(),
            # Memory is handled by the pynisher, not by the dask worker/nanny
            memory_limit=0,
        ),
        # Heartbeat every 10s
        heartbeat_interval=10000,
    )

def _close_dask_client(self, force: bool = False) -> None:
    """Shut down and discard the Dask client, if this instance owns it.

    Parameters
    ----------
    force : bool
        If True, close the client even when it was supplied by the user
        rather than created internally by ``_create_dask_client``.
    """
    # getattr guards allow safe calls before __init__ finished (e.g. from
    # __del__) or after a previous close already deleted the attributes.
    if getattr(self, "_dask_client", None) is not None and (
        force or getattr(self, "_is_dask_client_internally_created", False)
    ):
        self._dask_client.shutdown()
        self._dask_client.close()
        # Delete then reset to None so later getattr(..., None) checks
        # see "no client" consistently.
        del self._dask_client
        self._dask_client = None
        self._is_dask_client_internally_created = False
        del self._is_dask_client_internally_created

def _get_logger(self, name: str) -> PicklableClientLogger:
logger_name = "AutoML(%d):%s" % (self._seed, name)

Expand Down Expand Up @@ -747,17 +715,6 @@ def fit(
"autosklearn.metrics.Scorer."
)

# If no dask client was provided, we create one, so that we can
# start a ensemble process in parallel to smbo optimize
if self._dask_client is None and (
self._ensemble_class is not None
or self._n_jobs is not None
and self._n_jobs > 1
):
self._create_dask_client()
else:
self._is_dask_client_internally_created = False

self._dataset_name = dataset_name
self._stopwatch.start(self._dataset_name)

Expand Down Expand Up @@ -902,70 +859,85 @@ def fit(
)

n_meta_configs = self._initial_configurations_via_metalearning
_proc_smac = AutoMLSMBO(
config_space=self.configuration_space,
dataset_name=self._dataset_name,
backend=self._backend,
total_walltime_limit=time_left,
func_eval_time_limit=per_run_time_limit,
memory_limit=self._memory_limit,
data_memory_limit=self._data_memory_limit,
stopwatch=self._stopwatch,
n_jobs=self._n_jobs,
dask_client=self._dask_client,
start_num_run=self.num_run,
num_metalearning_cfgs=n_meta_configs,
config_file=configspace_path,
seed=self._seed,
metadata_directory=self._metadata_directory,
metrics=self._metrics,
resampling_strategy=self._resampling_strategy,
resampling_strategy_args=self._resampling_strategy_arguments,
include=self._include,
exclude=self._exclude,
disable_file_output=self._disable_evaluator_output,
get_smac_object_callback=self._get_smac_object_callback,
smac_scenario_args=self._smac_scenario_args,
scoring_functions=self._scoring_functions,
port=self._logger_port,
pynisher_context=self._multiprocessing_context,
ensemble_callback=proc_ensemble,
trials_callback=self._get_trials_callback,
)
with self._dask as dask_client:
resamp_args = self._resampling_strategy_arguments
_proc_smac = AutoMLSMBO(
config_space=self.configuration_space,
dataset_name=self._dataset_name,
backend=self._backend,
total_walltime_limit=time_left,
func_eval_time_limit=per_run_time_limit,
memory_limit=self._memory_limit,
data_memory_limit=self._data_memory_limit,
stopwatch=self._stopwatch,
n_jobs=self._n_jobs,
dask_client=dask_client,
start_num_run=self.num_run,
num_metalearning_cfgs=n_meta_configs,
config_file=configspace_path,
seed=self._seed,
metadata_directory=self._metadata_directory,
metrics=self._metrics,
resampling_strategy=self._resampling_strategy,
resampling_strategy_args=resamp_args,
include=self._include,
exclude=self._exclude,
disable_file_output=self._disable_evaluator_output,
get_smac_object_callback=self._get_smac_object_callback,
smac_scenario_args=self._smac_scenario_args,
scoring_functions=self._scoring_functions,
port=self._logger_port,
pynisher_context=self._multiprocessing_context,
ensemble_callback=proc_ensemble,
trials_callback=self._get_trials_callback,
)

(
self.runhistory_,
self.trajectory_,
self._budget_type,
) = _proc_smac.run_smbo()
trajectory_filename = os.path.join(
self._backend.get_smac_output_directory_for_run(self._seed),
"trajectory.json",
)
saveable_trajectory = [
list(entry[:2]) + [entry[2].get_dictionary()] + list(entry[3:])
for entry in self.trajectory_
]
with open(trajectory_filename, "w") as fh:
json.dump(saveable_trajectory, fh)

self._logger.info("Starting shutdown...")
# Wait until the ensemble process is finished to avoid shutting down
# while the ensemble builder tries to access the data
if proc_ensemble is not None:
self.ensemble_performance_history = list(proc_ensemble.history)

if len(proc_ensemble.futures) > 0:
# Now we wait for the future to return as it cannot be cancelled
# while it is running: https://stackoverflow.com/a/49203129
self._logger.info(
"Ensemble script still running, waiting for it to finish."
)
result = proc_ensemble.futures.pop().result()
if result:
ensemble_history, _ = result
self.ensemble_performance_history.extend(ensemble_history)
self._logger.info("Ensemble script finished, continue shutdown.")
(
self.runhistory_,
self.trajectory_,
self._budget_type,
) = _proc_smac.run_smbo()

trajectory_filename = os.path.join(
self._backend.get_smac_output_directory_for_run(self._seed),
"trajectory.json",
)
saveable_trajectory = [
list(entry[:2])
+ [entry[2].get_dictionary()]
+ list(entry[3:])
for entry in self.trajectory_
]
with open(trajectory_filename, "w") as fh:
json.dump(saveable_trajectory, fh)

self._logger.info("Starting shutdown...")
# Wait until the ensemble process is finished to avoid shutting
# down while the ensemble builder tries to access the data
if proc_ensemble is not None:
self.ensemble_performance_history = list(
proc_ensemble.history
)

if len(proc_ensemble.futures) > 0:
# Now we wait for the future to return as it cannot be
# cancelled while it is running
# * https://stackoverflow.com/a/49203129
self._logger.info(
"Ensemble script still running,"
" waiting for it to finish."
)
result = proc_ensemble.futures.pop().result()

if result:
ensemble_history, _ = result
self.ensemble_performance_history.extend(
ensemble_history
)

self._logger.info(
"Ensemble script finished, continue shutdown."
)

# save the ensemble performance history file
if len(self.ensemble_performance_history) > 0:
Expand Down Expand Up @@ -1054,7 +1026,7 @@ def _log_fit_setup(self) -> None:
self._logger.debug(
" multiprocessing_context: %s", str(self._multiprocessing_context)
)
self._logger.debug(" dask_client: %s", str(self._dask_client))
self._logger.debug(" dask_client: %s", str(self._dask))
self._logger.debug(" precision: %s", str(self.precision))
self._logger.debug(
" disable_evaluator_output: %s", str(self._disable_evaluator_output)
Expand Down Expand Up @@ -1090,7 +1062,6 @@ def __sklearn_is_fitted__(self) -> bool:

def _fit_cleanup(self) -> None:
self._logger.info("Closing the dask infrastructure")
self._close_dask_client()
self._logger.info("Finished closing the dask infrastructure")

# Clean up the logger
Expand Down Expand Up @@ -1555,48 +1526,48 @@ def fit_ensemble(
# Make sure that input is valid
y = self.InputValidator.target_validator.transform(y)

# Create a client if needed
if self._dask_client is None:
self._create_dask_client()
else:
self._is_dask_client_internally_created = False

metrics = metrics if metrics is not None else self._metrics
if not isinstance(metrics, Sequence):
metrics = [metrics]

# Use the current thread to start the ensemble builder process
# The function ensemble_builder_process will internally create a ensemble
# builder in the provide dask client
manager = EnsembleBuilderManager(
start_time=time.time(),
time_left_for_ensembles=self._time_for_task,
backend=copy.deepcopy(self._backend),
dataset_name=dataset_name if dataset_name else self._dataset_name,
task=task if task else self._task,
metrics=metrics if metrics is not None else self._metrics,
ensemble_class=(
ensemble_class if ensemble_class is not None else self._ensemble_class
),
ensemble_kwargs=(
ensemble_kwargs
if ensemble_kwargs is not None
else self._ensemble_kwargs
),
ensemble_nbest=ensemble_nbest if ensemble_nbest else self._ensemble_nbest,
max_models_on_disc=self._max_models_on_disc,
seed=self._seed,
precision=precision if precision else self.precision,
max_iterations=1,
read_at_most=None,
memory_limit=self._memory_limit,
random_state=self._seed,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
)
manager.build_ensemble(self._dask_client)
future = manager.futures.pop()
result = future.result()
with self._dask as dask_client:
manager = EnsembleBuilderManager(
start_time=time.time(),
time_left_for_ensembles=self._time_for_task,
backend=copy.deepcopy(self._backend),
dataset_name=dataset_name if dataset_name else self._dataset_name,
task=task if task else self._task,
metrics=metrics if metrics is not None else self._metrics,
ensemble_class=(
ensemble_class
if ensemble_class is not None
else self._ensemble_class
),
ensemble_kwargs=(
ensemble_kwargs
if ensemble_kwargs is not None
else self._ensemble_kwargs
),
ensemble_nbest=ensemble_nbest
if ensemble_nbest
else self._ensemble_nbest,
max_models_on_disc=self._max_models_on_disc,
seed=self._seed,
precision=precision if precision else self.precision,
max_iterations=1,
read_at_most=None,
memory_limit=self._memory_limit,
random_state=self._seed,
logger_port=self._logger_port,
pynisher_context=self._multiprocessing_context,
)
manager.build_ensemble(dask_client)
future = manager.futures.pop()
result = future.result()

if result is None:
raise ValueError(
"Error building the ensemble - please check the log file and command "
Expand All @@ -1606,7 +1577,6 @@ def fit_ensemble(
self._ensemble_class = ensemble_class

self._load_models()
self._close_dask_client()
return self

def _load_models(self):
Expand Down Expand Up @@ -2295,7 +2265,7 @@ def _create_search_space(

def __getstate__(self) -> dict[str, Any]:
# Cannot serialize a client!
self._dask_client = None
self._dask = None
self.logging_server = None
self.stop_logging_server = None
return self.__dict__
Expand All @@ -2304,8 +2274,6 @@ def __del__(self) -> None:
# Clean up the logger
self._clean_logger()

self._close_dask_client()


class AutoMLClassifier(AutoML):

Expand Down
Loading

0 comments on commit af9d469

Please sign in to comment.