Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable Mypy in evaluation (except Train Evaluator) #1077

Merged
merged 5 commits into from
Feb 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ repos:
args: [--show-error-codes]
name: mypy auto-sklearn-util
files: autosklearn/util
- id: mypy
args: [--show-error-codes]
name: mypy auto-sklearn-evaluation
files: autosklearn/evaluation
- repo: https://gitlab.com/pycqa/flake8
rev: 3.8.3
hooks:
Expand Down
68 changes: 50 additions & 18 deletions autosklearn/evaluation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
from queue import Empty
import time
import traceback
from typing import Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast

from ConfigSpace import Configuration
import numpy as np
import pynisher
from smac.runhistory.runhistory import RunInfo, RunValue
from smac.stats.stats import Stats
from smac.tae import StatusType, TAEAbortException
from smac.tae.execute_func import AbstractTAFunc

Expand All @@ -23,11 +24,17 @@
import autosklearn.evaluation.train_evaluator
import autosklearn.evaluation.test_evaluator
import autosklearn.evaluation.util
from autosklearn.util.logging_ import get_named_client_logger
from autosklearn.evaluation.train_evaluator import TYPE_ADDITIONAL_INFO
from autosklearn.util.backend import Backend
from autosklearn.util.logging_ import PickableLoggerAdapter, get_named_client_logger
from autosklearn.util.parallel import preload_modules


def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
def fit_predict_try_except_decorator(
ta: Callable,
queue: multiprocessing.Queue,
cost_for_crash: float,
**kwargs: Any) -> None:

try:
return ta(queue=queue, **kwargs)
Expand Down Expand Up @@ -66,7 +73,7 @@ def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs):
queue.close()


def get_cost_of_crash(metric):
def get_cost_of_crash(metric: Scorer) -> float:

# The metric must always be defined to extract optimum/worst
if not isinstance(metric, Scorer):
Expand All @@ -85,8 +92,11 @@ def get_cost_of_crash(metric):
return worst_possible_result


def _encode_exit_status(exit_status):
def _encode_exit_status(exit_status: Union[str, int, Type[BaseException]]
) -> Union[str, int]:
try:
# If it can be dumped, it is JSON-serializable (an int or a str); the cast below only informs mypy
exit_status = cast(int, exit_status)
json.dumps(exit_status)
return exit_status
except (TypeError, OverflowError):
Expand All @@ -97,13 +107,31 @@ def _encode_exit_status(exit_status):
# easier debugging of potential crashes
class ExecuteTaFuncWithQueue(AbstractTAFunc):

def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
cost_for_crash, abort_on_first_run_crash, port, pynisher_context,
initial_num_run=1, stats=None,
run_obj='quality', par_factor=1, scoring_functions=None,
output_y_hat_optimization=True, include=None, exclude=None,
memory_limit=None, disable_file_output=False, init_params=None,
budget_type=None, ta=False, **resampling_strategy_args):
def __init__(
self,
backend: Backend,
autosklearn_seed: int,
resampling_strategy: Union[str, BaseCrossValidator, _RepeatedSplits, BaseShuffleSplit],
metric: Scorer,
cost_for_crash: float,
abort_on_first_run_crash: bool,
port: int,
pynisher_context: str,
initial_num_run: int = 1,
stats: Optional[Stats] = None,
run_obj: str = 'quality',
par_factor: int = 1,
scoring_functions: Optional[List[Scorer]] = None,
output_y_hat_optimization: bool = True,
include: Optional[List[str]] = None,
exclude: Optional[List[str]] = None,
memory_limit: Optional[int] = None,
disable_file_output: bool = False,
init_params: Optional[Dict[str, Any]] = None,
budget_type: Optional[str] = None,
ta: Optional[Callable] = None,
**resampling_strategy_args: Any,
):

if resampling_strategy == 'holdout':
eval_function = autosklearn.evaluation.train_evaluator.eval_holdout
Expand Down Expand Up @@ -180,7 +208,7 @@ def __init__(self, backend, autosklearn_seed, resampling_strategy, metric,
self.port = port
self.pynisher_context = pynisher_context
if self.port is None:
self.logger = logging.getLogger("TAE")
self.logger: Union[logging.Logger, PickableLoggerAdapter] = logging.getLogger("TAE")
else:
self.logger = get_named_client_logger(
name="TAE",
Expand Down Expand Up @@ -261,6 +289,10 @@ def run(
instance_specific: Optional[str] = None,
) -> Tuple[StatusType, float, float, Dict[str, Union[int, float, str, Dict, List, Tuple]]]:

# Additional information of each of the tae executions
# Defined upfront for mypy
additional_run_info: TYPE_ADDITIONAL_INFO = {}

context = multiprocessing.get_context(self.pynisher_context)
preload_modules(context)
queue = context.Queue()
Expand All @@ -272,7 +304,7 @@ def run(
init_params.update(self.init_params)

if self.port is None:
logger = logging.getLogger("pynisher")
logger: Union[logging.Logger, PickableLoggerAdapter] = logging.getLogger("pynisher")
else:
logger = get_named_client_logger(
name="pynisher",
Expand Down Expand Up @@ -320,11 +352,11 @@ def run(
except Exception as e:
exception_traceback = traceback.format_exc()
error_message = repr(e)
additional_info = {
additional_run_info.update({
'traceback': exception_traceback,
'error': error_message
}
return StatusType.CRASHED, self.cost_for_crash, 0.0, additional_info
})
return StatusType.CRASHED, self.worst_possible_result, 0.0, additional_run_info

if obj.exit_status in (pynisher.TimeoutException, pynisher.MemorylimitException):
# Even if the pynisher thinks that a timeout or memout occurred,
Expand Down Expand Up @@ -359,7 +391,7 @@ def run(
elif obj.exit_status is pynisher.MemorylimitException:
status = StatusType.MEMOUT
additional_run_info = {
'error': 'Memout (used more than %d MB).' % self.memory_limit
"error": "Memout (used more than {} MB).".format(self.memory_limit)
}
else:
raise ValueError(obj.exit_status)
Expand Down
Loading