From 64b8b9cb351f4503db9e99b0163e50b6df247f9f Mon Sep 17 00:00:00 2001 From: Francisco Rivera Valverde <44504424+franchuterivera@users.noreply.github.com> Date: Tue, 5 Jan 2021 09:29:42 +0100 Subject: [PATCH 01/10] Add last history to ensemble (#1046) --- autosklearn/automl.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/autosklearn/automl.py b/autosklearn/automl.py index 4bdf7900cd..ea741eb93f 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -761,19 +761,21 @@ def fit( if proc_ensemble is not None: self.ensemble_performance_history = list(proc_ensemble.history) - # save the ensemble performance history file - if len(self.ensemble_performance_history) > 0: - pd.DataFrame(self.ensemble_performance_history).to_json( - os.path.join(self._backend.internals_directory, 'ensemble_history.json')) - if len(proc_ensemble.futures) > 0: - future = proc_ensemble.futures.pop() # Now we need to wait for the future to return as it cannot be cancelled while it # is running: https://stackoverflow.com/a/49203129 self._logger.info("Ensemble script still running, waiting for it to finish.") - future.result() + result = proc_ensemble.futures.pop().result() + if result: + ensemble_history, _, _, _, _ = result + self.ensemble_performance_history.extend(ensemble_history) self._logger.info("Ensemble script finished, continue shutdown.") + # save the ensemble performance history file + if len(self.ensemble_performance_history) > 0: + pd.DataFrame(self.ensemble_performance_history).to_json( + os.path.join(self._backend.internals_directory, 'ensemble_history.json')) + if load_models: self._logger.info("Loading models...") self._load_models() From 4d19551a7824ec04f47d41e83ef39d1b2b8d48d6 Mon Sep 17 00:00:00 2001 From: ROHIT AGARWAL Date: Tue, 5 Jan 2021 15:00:58 +0530 Subject: [PATCH 02/10] Example multiple metric (#1045) * Adding example which depicts on how to calc multiple score per run * Update abstract_evaluator.py Co-authored-by: Rohit Agarwal --- autosklearn/evaluation/abstract_evaluator.py | 13 ++-- .../example_calc_multiple_metrics.py | 66 +++++++++++++++++++ 2 files changed, 70 insertions(+), 9 deletions(-) create mode 100644 examples/40_advanced/example_calc_multiple_metrics.py diff --git a/autosklearn/evaluation/abstract_evaluator.py b/autosklearn/evaluation/abstract_evaluator.py index 1ea707f5d7..18054ce36d 100644 --- a/autosklearn/evaluation/abstract_evaluator.py +++ b/autosklearn/evaluation/abstract_evaluator.py @@ -19,7 +19,7 @@ from autosklearn.pipeline.implementations.util import ( convert_multioutput_multiclass_to_multilabel ) -from autosklearn.metrics import calculate_score, CLASSIFICATION_METRICS, REGRESSION_METRICS +from autosklearn.metrics import calculate_score from autosklearn.util.logging_ import get_named_client_logger from ConfigSpace import Configuration @@ -264,14 +264,9 @@ def _loss(self, y_true, y_hat, scoring_functions=None): scoring_functions=scoring_functions) if hasattr(score, '__len__'): - # TODO: instead of using self.metric, it should use all metrics given by key. - # But now this throws error... 
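# Context for this hunk -- a minimal sketch, not part of the patch: every
# auto-sklearn Scorer stores its optimum internally, so a metric-agnostic loss
# can be computed as `optimum - score`, which makes lower always better no
# matter which metric the user picked. For example, with the built-in
# accuracy scorer (optimum = 1.0):
from autosklearn.metrics import accuracy

score = 0.93                        # raw metric value for one run
loss = accuracy._optimum - score    # 1.0 - 0.93 = 0.07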
- if self.task_type in CLASSIFICATION_TASKS: - err = {key: metric._optimum - score[key] for key, metric in - CLASSIFICATION_METRICS.items() if key in score} - else: - err = {key: metric._optimum - score[key] for key, metric in - REGRESSION_METRICS.items() if key in score} + err = {metric.name: metric._optimum - score[metric.name] + for metric in scoring_functions} + err[self.metric.name] = self.metric._optimum - score[self.metric.name] else: err = self.metric._optimum - score diff --git a/examples/40_advanced/example_calc_multiple_metrics.py b/examples/40_advanced/example_calc_multiple_metrics.py new file mode 100644 index 0000000000..7139d68832 --- /dev/null +++ b/examples/40_advanced/example_calc_multiple_metrics.py @@ -0,0 +1,66 @@ +# -*- encoding: utf-8 -*- +""" +======= +Metrics +======= + +In *Auto-sklearn*, model is optimized over a metric, either built-in or +custom metric. Moreover, it is also possible to calculate multiple metrics +per run. The following examples show how to calculate metrics built-in +and self-defined metrics for a classification problem. +""" + +import autosklearn.classification +import custom_metrics +import pandas as pd +import sklearn.datasets +import sklearn.metrics +from autosklearn.metrics import balanced_accuracy, precision, recall, f1 + + +def get_metric_result(cv_results): + results = pd.DataFrame.from_dict(cv_results) + results = results[results['status'] == "Success"] + cols = ['rank_test_scores', 'param_classifier:__choice__', 'mean_test_score'] + cols.extend([key for key in cv_results.keys() if key.startswith('metric_')]) + return results[cols] + + +if __name__ == "__main__": + ############################################################################ + # Data Loading + # ============ + + X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) + X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) + + ############################################################################ + # Build and fit a classifier + # ========================== + + error_rate = autosklearn.metrics.make_scorer( + name='custom_error', + score_func=custom_metrics.error, + optimum=0, + greater_is_better=False, + needs_proba=False, + needs_threshold=False + ) + cls = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=120, + per_run_time_limit=30, + scoring_functions=[balanced_accuracy, precision, recall, f1, error_rate] + ) + cls.fit(X_train, y_train, X_test, y_test) + + ########################################################################### + # Get the Score of the final ensemble + # =================================== + + predictions = cls.predict(X_test) + print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) + + print("#" * 80) + print("Metric results") + print(get_metric_result(cls.cv_results_).to_string(index=False)) From fdea5a61e155894d61940d1e4c44c86a7493ddf2 Mon Sep 17 00:00:00 2001 From: Katharina Eggensperger Date: Fri, 8 Jan 2021 17:42:14 +0100 Subject: [PATCH 03/10] Update __version__.py --- autosklearn/__version__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/autosklearn/__version__.py b/autosklearn/__version__.py index 6160ca5efb..299c1f6235 100644 --- a/autosklearn/__version__.py +++ b/autosklearn/__version__.py @@ -1,4 +1,4 @@ """Version information.""" # The following line *must* be the last in the module, exactly as formatted: -__version__ = "0.12.1" +__version__ = "0.12.2dev" From 4b851c4cc235f81bac70b39d9eb3de2a4354ad7c Mon Sep 
17 00:00:00 2001 From: Katharina Eggensperger Date: Mon, 11 Jan 2021 10:22:24 +0100 Subject: [PATCH 04/10] Add Binder links for examples (#1052) * init * ADD swig * Copy notebooks for binder * trigger ghpages * RM second binder dir * ADD exception for custom metrics example * FIX notebook dir and askl being not installed * FIX comma * fix dirnames * UNDO buildingdocs for binder branch --- .binder/apt.txt | 2 ++ .binder/postBuild | 43 ++++++++++++++++++++++++++++++++++++++++ .binder/requirements.txt | 1 + doc/conf.py | 21 +++++++++++++++++++- 4 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 .binder/apt.txt create mode 100644 .binder/postBuild create mode 100644 .binder/requirements.txt diff --git a/.binder/apt.txt b/.binder/apt.txt new file mode 100644 index 0000000000..059bded08b --- /dev/null +++ b/.binder/apt.txt @@ -0,0 +1,2 @@ +build-essential +swig diff --git a/.binder/postBuild b/.binder/postBuild new file mode 100644 index 0000000000..0677d22652 --- /dev/null +++ b/.binder/postBuild @@ -0,0 +1,43 @@ +#!/bin/bash + +set -e + +python -m pip install .[docs,examples] + +# Taken from https://github.com/scikit-learn/scikit-learn/blob/22cd233e1932457947e9994285dc7fd4e93881e4/.binder/postBuild +# under BSD3 license, copyright the scikit-learn contributors + +# This script is called in a binder context. When this script is called, we are +# inside a git checkout of the automl/auto-sklearn repo. This script +# generates notebooks from the auto-sklearn python examples. + +if [[ ! -f /.dockerenv ]]; then + echo "This script was written for repo2docker and is supposed to run inside a docker container." + echo "Exiting because this script can delete data if run outside of a docker container." + exit 1 +fi + +# Copy content we need from the auto-sklearn repo +TMP_CONTENT_DIR=/tmp/auto-sklearn +mkdir -p $TMP_CONTENT_DIR +cp -r examples .binder $TMP_CONTENT_DIR +# delete everything in current directory including dot files and dot folders +find . -delete + +# Generate notebooks and remove other files from examples folder +GENERATED_NOTEBOOKS_DIR=examples +cp -r $TMP_CONTENT_DIR/examples $GENERATED_NOTEBOOKS_DIR + +find $GENERATED_NOTEBOOKS_DIR -name 'example_*.py' -exec sphx_glr_python_to_jupyter.py '{}' + +# Keep __init__.py and custom_metrics.py +NON_NOTEBOOKS=$(find $GENERATED_NOTEBOOKS_DIR -type f | grep -v '\.ipynb' | grep -v 'init' | grep -v 'custom_metrics') +rm -f $NON_NOTEBOOKS + +# Modify path to be consistent by the path given by sphinx-gallery +mkdir notebooks +mv $GENERATED_NOTEBOOKS_DIR notebooks/ + +# Put the .binder folder back (may be useful for debugging purposes) +mv $TMP_CONTENT_DIR/.binder . +# Final clean up +rm -rf $TMP_CONTENT_DIR diff --git a/.binder/requirements.txt b/.binder/requirements.txt new file mode 100644 index 0000000000..3c8d7e7822 --- /dev/null +++ b/.binder/requirements.txt @@ -0,0 +1 @@ +-r ../requirements.txt diff --git a/doc/conf.py b/doc/conf.py index 68e69ee0d1..91ed264eef 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -59,6 +59,13 @@ autosectionlabel_prefix_document = True # Sphinx-gallery configuration. 
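# A sketch of what the 'binder' settings added below produce (assuming the
# standard mybinder.org URL scheme; the 'auto_examples' segment assumes the
# default sphinx-gallery output directory name): every rendered example page
# gets a launch badge pointing at a link roughly of the form
#
#   https://mybinder.org/v2/gh/automl/auto-sklearn/master?filepath=notebooks/auto_examples/40_advanced/example_calc_multiple_metrics.ipynb
#
# with 'development' in place of 'master' for dev builds, per binder_branch
# below. Binder then opens the notebooks that .binder/postBuild copies into
# the 'notebooks/' directory of the built documentation.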
+ +# get current branch +binder_branch = 'master' +import autosklearn +if "dev" in autosklearn.__version__: + binder_branch = "development" + sphinx_gallery_conf = { # path to the examples 'examples_dirs': '../examples', @@ -71,7 +78,19 @@ #}, 'backreferences_dir': None, 'filename_pattern': 'example.*.py$', - 'ignore_pattern': r'custom_metrics\.py|__init__\.py' + 'ignore_pattern': r'custom_metrics\.py|__init__\.py', + 'binder': { + # Required keys + 'org': 'automl', + 'repo': 'auto-sklearn', + 'branch': binder_branch, + 'binderhub_url': 'https://mybinder.org', + 'dependencies': ['../.binder/apt.txt', '../.binder/requirements.txt'], + #'filepath_prefix': '' # A prefix to prepend to any filepaths in Binder links. + # Jupyter notebooks for Binder will be copied to this directory (relative to built documentation root). + 'notebooks_dir': 'notebooks/', + 'use_jupyter_lab': True, # Whether Binder links should start Jupyter Lab instead of the Jupyter Notebook interface. + }, } # Add any paths that contain templates here, relative to this directory. From a615c3eee289150c20a4b5ad84d27c35985973c0 Mon Sep 17 00:00:00 2001 From: Francisco Rivera Valverde <44504424+franchuterivera@users.noreply.github.com> Date: Tue, 12 Jan 2021 16:16:14 +0100 Subject: [PATCH 05/10] Add fit/transform interface to the data validation (#1041) * Initial commit for new data scheme * New input validator schema * Incorporate feedback from #1041 * Missing feedback from #1041 * Deleted missing file * Merge conflict with two loggers * Improving coverage * Inverse transform unknown handling * Test logger client for smbo error * Try to remove random smbo error * Mode debug msg capabilites for smbo * Also print root logger * Improved checking * Check Pandas transformation * More robust time checking --- .pre-commit-config.yaml | 8 +- autosklearn/automl.py | 80 ++- autosklearn/data/feature_validator.py | 463 ++++++++++++++ autosklearn/data/target_validator.py | 388 ++++++++++++ autosklearn/data/validation.py | 727 ++++----------------- autosklearn/util/hash.py | 29 - test/test_automl/test_automl.py | 56 +- test/test_automl/test_estimators.py | 21 +- test/test_data/test_feature_validator.py | 547 ++++++++++++++++ test/test_data/test_target_validator.py | 503 +++++++++++++++ test/test_data/test_validation.py | 775 ++++------------------- test/test_util/test_hash.py | 62 -- 12 files changed, 2238 insertions(+), 1421 deletions(-) create mode 100644 autosklearn/data/feature_validator.py create mode 100644 autosklearn/data/target_validator.py delete mode 100644 autosklearn/util/hash.py create mode 100644 test/test_data/test_feature_validator.py create mode 100644 test/test_data/test_target_validator.py delete mode 100644 test/test_util/test_hash.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1bc50bec0c..af23e08eff 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,15 +3,19 @@ repos: rev: v0.761 hooks: - id: mypy + args: [--show-error-codes] name: mypy auto-sklearn-ensembles files: autosklearn/ensembles - id: mypy + args: [--show-error-codes] name: mypy auto-sklearn-metrics files: autosklearn/metrics - id: mypy + args: [--show-error-codes] name: mypy auto-sklearn-data files: autosklearn/data - id: mypy + args: [--show-error-codes] name: mypy auto-sklearn-util files: autosklearn/util - repo: https://gitlab.com/pycqa/flake8 @@ -19,7 +23,7 @@ repos: hooks: - id: flake8 name: flake8 auto-sklearn - files: autosklearn/* + files: autosklearn/.* - id: flake8 name: flake8 file-order-data files: 
autosklearn/data @@ -42,4 +46,4 @@ repos: - flake8-import-order - id: flake8 name: flake8 autosklearn-test - files: test/* + files: test/.* diff --git a/autosklearn/automl.py b/autosklearn/automl.py index ea741eb93f..d9119beb6e 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -9,6 +9,7 @@ import sys import time from typing import Any, Dict, Optional, List, Union +import uuid import unittest.mock import warnings import tempfile @@ -29,8 +30,8 @@ import joblib import sklearn.utils import scipy.sparse -from sklearn.metrics._classification import type_of_target from sklearn.utils.validation import check_is_fitted +from sklearn.metrics._classification import type_of_target from sklearn.dummy import DummyClassifier, DummyRegressor from autosklearn.metrics import Scorer @@ -51,7 +52,6 @@ from autosklearn.ensemble_builder import EnsembleBuilderManager from autosklearn.ensembles.singlebest_ensemble import SingleBest from autosklearn.smbo import AutoMLSMBO -from autosklearn.util.hash import hash_array_or_matrix from autosklearn.metrics import f1_macro, accuracy, r2 from autosklearn.constants import MULTILABEL_CLASSIFICATION, MULTICLASS_CLASSIFICATION, \ REGRESSION_TASKS, REGRESSION, BINARY_CLASSIFICATION, MULTIOUTPUT_REGRESSION, \ @@ -217,7 +217,7 @@ def __init__(self, self._debug_mode = debug_mode - self.InputValidator = InputValidator() + self.InputValidator = None # type: Optional[InputValidator] # The ensemble performance history through time self.ensemble_performance_history = [] @@ -436,29 +436,34 @@ def fit( dataset_name: Optional[str] = None, only_return_configuration_space: Optional[bool] = False, load_models: bool = True, + is_classification: bool = False, ): if dataset_name is None: - dataset_name = hash_array_or_matrix(X) - # The first thing we have to do is create the logger to update the backend + dataset_name = str(uuid.uuid1(clock_seq=os.getpid())) + + # By default try to use the TCP logging port or get a new port + self._logger_port = logging.handlers.DEFAULT_TCP_LOGGING_PORT self._logger = self._get_logger(dataset_name) + + # The first thing we have to do is create the logger to update the backend self._backend.setup_logger(self._logger_port) self._backend.save_start_time(self._seed) self._stopwatch = StopWatch() - # Employ the user feature types if provided - self.InputValidator.register_user_feat_type(feat_type, X) - # Make sure that input is valid # Performs Ordinal one hot encoding to the target # both for train and test data - X, y = self.InputValidator.validate(X, y) + self.InputValidator = InputValidator( + is_classification=is_classification, + feat_type=feat_type, + logger_port=self._logger_port, + ) + self.InputValidator.fit(X_train=X, y_train=y, X_test=X_test, y_test=y_test) + X, y = self.InputValidator.transform(X, y) if X_test is not None: - X_test, y_test = self.InputValidator.validate(X_test, y_test) - if len(y.shape) != len(y_test.shape): - raise ValueError('Target value shapes do not match: %s vs %s' - % (y.shape, y_test.shape)) + X_test, y_test = self.InputValidator.transform(X_test, y_test) X, y = self.subsample_if_too_large( X=X, @@ -496,8 +501,8 @@ def fit( self._dataset_name = dataset_name self._stopwatch.start_task(self._dataset_name) - if feat_type is None and self.InputValidator.feature_types: - feat_type = self.InputValidator.feature_types + if feat_type is None and self.InputValidator.feature_validator.feat_type: + feat_type = self.InputValidator.feature_validator.feat_type # Produce debug information to the logfile 
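# A condensed recap of the validation flow wired in by this hunk (a sketch of
# the calls already shown above, assuming a classification task): a single
# validator is fitted once and then transforms every split identically.
#
#     validator = InputValidator(is_classification=True,
#                                feat_type=feat_type,
#                                logger_port=self._logger_port)
#     validator.fit(X_train=X, y_train=y, X_test=X_test, y_test=y_test)
#     X, y = validator.transform(X, y)
#     X_test, y_test = validator.transform(X_test, y_test)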
self._logger.debug('Starting to print environment information') @@ -847,7 +852,10 @@ def subsample_if_too_large(X, y, logger, seed, memory_limit, task): def refit(self, X, y): # Make sure input data is valid - X, y = self.InputValidator.validate(X, y) + if self.InputValidator is None or not self.InputValidator._is_fitted: + raise ValueError("refit() is only supported after calling fit. Kindly call first " + "the estimator fit() method.") + X, y = self.InputValidator.transform(X, y) if self.models_ is None or len(self.models_) == 0 or self.ensemble_ is None: self._load_models() @@ -929,7 +937,10 @@ def predict(self, X, batch_size=None, n_jobs=1): "if 'ensemble_size != 0'") # Make sure that input is valid - X = self.InputValidator.validate_features(X) + if self.InputValidator is None or not self.InputValidator._is_fitted: + raise ValueError("predict() can only be called after performing fit(). Kindly call " + "the estimator fit() method first.") + X = self.InputValidator.feature_validator.transform(X) # Parallelize predictions across models with n_jobs processes. # Each process computes predictions in chunks of batch_size rows. @@ -987,7 +998,10 @@ def fit_ensemble(self, y, task=None, precision=32, self._logger = self._get_logger(dataset_name) # Make sure that input is valid - y = self.InputValidator.validate_target(y, is_classification=True) + if self.InputValidator is None or not self.InputValidator._is_fitted: + raise ValueError("fit_ensemble() can only be called after fit. Please call the " + "estimator fit() method prior to fit_ensemble().") + y = self.InputValidator.target_validator.transform(y) # Create a client if needed if self._dask_client is None: @@ -1110,7 +1124,10 @@ def score(self, X, y): prediction = self.predict(X) # Make sure that input is valid - X, y = self.InputValidator.validate(X, y) + if self.InputValidator is None or not self.InputValidator._is_fitted: + raise ValueError("score() is only supported after calling fit. 
Kindly call first " + "the estimator fit() method.") + y = self.InputValidator.target_validator.transform(y) # Encode the prediction using the input validator # We train autosklearn with a encoded version of y, @@ -1118,7 +1135,7 @@ def score(self, X, y): # Above call to validate() encodes the y given for score() # Below call encodes the prediction, so we compare in the # same representation domain - prediction = self.InputValidator.encode_target(prediction) + prediction = self.InputValidator.target_validator.transform(prediction) return calculate_score(solution=y, prediction=prediction, @@ -1366,15 +1383,7 @@ def fit( only_return_configuration_space: bool = False, load_models: bool = True, ): - - # We first validate the dtype of the target provided by the user - # In doing so, we also fit the internal encoder for classification - # In case y_test is provided, we proactively check their type - # and make sure the enconding accounts for both y_test/y_train classes - input_y = self.InputValidator.join_and_check(y, y_test) if y_test is not None else y - y_task = type_of_target( - self.InputValidator.validate_target(input_y, is_classification=True) - ) + y_task = type_of_target(y) task = self._task_mapping.get(y_task) if task is None: raise ValueError('Cannot work on data of type %s' % y_task) @@ -1394,18 +1403,22 @@ def fit( dataset_name=dataset_name, only_return_configuration_space=only_return_configuration_space, load_models=load_models, + is_classification=True, ) def predict(self, X, batch_size=None, n_jobs=1): predicted_probabilities = super().predict(X, batch_size=batch_size, n_jobs=n_jobs) - if self.InputValidator.is_single_column_target() == 1: + if self.InputValidator is None or not self.InputValidator._is_fitted: + raise ValueError("predict() is only supported after calling fit. 
Kindly call first " + "the estimator fit() method.") + if self.InputValidator.target_validator.is_single_column_target(): predicted_indexes = np.argmax(predicted_probabilities, axis=1) else: predicted_indexes = (predicted_probabilities > 0.5).astype(int) - return self.InputValidator.decode_target(predicted_indexes) + return self.InputValidator.target_validator.inverse_transform(predicted_indexes) def predict_proba(self, X, batch_size=None, n_jobs=1): return super().predict(X, batch_size=batch_size, n_jobs=n_jobs) @@ -1433,9 +1446,7 @@ def fit( # Check the data provided in y # After the y data type is validated, # check the task type - y_task = type_of_target( - self.InputValidator.validate_target(y) - ) + y_task = type_of_target(y) task = self._task_mapping.get(y_task) if task is None: raise ValueError('Cannot work on data of type %s' % y_task) @@ -1451,4 +1462,5 @@ def fit( dataset_name=dataset_name, only_return_configuration_space=only_return_configuration_space, load_models=load_models, + is_classification=False, ) diff --git a/autosklearn/data/feature_validator.py b/autosklearn/data/feature_validator.py new file mode 100644 index 0000000000..1e291ffada --- /dev/null +++ b/autosklearn/data/feature_validator.py @@ -0,0 +1,463 @@ +import functools +import logging +import typing + +import numpy as np + +import pandas as pd +from pandas.api.types import is_numeric_dtype + +import scipy.sparse + +import sklearn.utils +from sklearn import preprocessing +from sklearn.base import BaseEstimator +from sklearn.compose import make_column_transformer +from sklearn.exceptions import NotFittedError + +from autosklearn.util.logging_ import PickableLoggerAdapter + + +SUPPORTED_FEAT_TYPES = typing.Union[ + typing.List, + pd.DataFrame, + np.ndarray, + scipy.sparse.bsr_matrix, + scipy.sparse.coo_matrix, + scipy.sparse.csc_matrix, + scipy.sparse.csr_matrix, + scipy.sparse.dia_matrix, + scipy.sparse.dok_matrix, + scipy.sparse.lil_matrix, +] + + +class FeatureValidator(BaseEstimator): + """ + A class to pre-process features. In this regards, the format of the data is checked, + and if applicable, features are encoded + Attributes + ---------- + feat_type: typing.Optional[typing.List[str]] + In case the data is not a pandas DataFrame, this list indicates + which columns should be treated as categorical + data_type: + Class name of the data type provided during fit. + encoder: typing.Optional[BaseEstimator] + Host a encoder object if the data requires transformation (for example, + if provided a categorical column in a pandas DataFrame) + enc_columns: typing.List[str] + List of columns that where encoded + """ + def __init__(self, + feat_type: typing.Optional[typing.List[str]] = None, + logger: typing.Optional[PickableLoggerAdapter] = None, + ) -> None: + # If a dataframe was provided, we populate + # this attribute with the column types from the dataframe + # That is, this attribute contains whether autosklearn + # should treat a column as categorical or numerical + # During fit, if the user provided feat_types, the user + # constrain is honored. If not, this attribute is used. 
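# For illustration (not part of the patch): with a three-column numpy input
# whose first column is categorical, the user would pass
# feat_type=['categorical', 'numerical', 'numerical']; for DataFrames the
# same information is instead derived from DataFrame.dtypes.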
+ self.feat_type = feat_type # type: typing.Optional[typing.List[str]] + + # Register types to detect unsupported data format changes + self.data_type = None # type: typing.Optional[type] + self.dtypes = [] # type: typing.List[str] + self.column_order = [] # type: typing.List[str] + + self.encoder = None # type: typing.Optional[BaseEstimator] + self.enc_columns = [] # type: typing.List[str] + + self.logger = logger if logger is not None else logging.getLogger(__name__) + + self._is_fitted = False + + def fit( + self, + X_train: SUPPORTED_FEAT_TYPES, + X_test: typing.Optional[SUPPORTED_FEAT_TYPES] = None, + ) -> BaseEstimator: + """ + Validates and fit a categorical encoder (if needed) to the features. + The supported data types are List, numpy arrays and pandas DataFrames. + CSR sparse data types are also supported + + Parameters + ---------- + X_train: SUPPORTED_FEAT_TYPES + A set of features that are going to be validated (type and dimensionality + checks) and a encoder fitted in the case the data needs encoding + X_test: typing.Optional[SUPPORTED_FEAT_TYPES] + A hold out set of data used for checking + """ + + # If a list was provided, it will be converted to pandas + if isinstance(X_train, list): + X_train, X_test = self.list_to_dataframe(X_train, X_test) + + # Register the user provided feature types + if self.feat_type is not None: + if hasattr(X_train, "iloc"): + raise ValueError("When providing a DataFrame to Auto-Sklearn, we extract " + "the feature types from the DataFrame.dtypes. That is, " + "providing the option feat_type to the fit method is not " + "supported when using a Dataframe. Please make sure that the " + "type of each column in your DataFrame is properly set. " + "More details about having the correct data type in your " + "DataFrame can be seen in " + "https://pandas.pydata.org/pandas-docs/stable/reference" + "/api/pandas.DataFrame.astype.html") + # Some checks if self.feat_type is provided + if len(self.feat_type) != np.shape(X_train)[1]: + raise ValueError('Array feat_type does not have same number of ' + 'variables as X has features. %d vs %d.' 
% + (len(self.feat_type), np.shape(X_train)[1])) + if not all([isinstance(f, str) for f in self.feat_type]): + raise ValueError('Array feat_type must only contain strings.') + + for ft in self.feat_type: + if ft.lower() not in ['categorical', 'numerical']: + raise ValueError('Only `Categorical` and `Numerical` are ' + 'valid feature types, you passed `%s`' % ft) + + self._check_data(X_train) + + if X_test is not None: + self._check_data(X_test) + + if np.shape(X_train)[1] != np.shape(X_test)[1]: + raise ValueError("The feature dimensionality of the train and test " + "data does not match train({}) != test({})".format( + np.shape(X_train)[1], + np.shape(X_test)[1] + )) + + # Fit on the training data + self._fit(X_train) + + self._is_fitted = True + + return self + + def _fit( + self, + X: SUPPORTED_FEAT_TYPES, + ) -> BaseEstimator: + """ + In case input data is a pandas DataFrame, this utility encodes the user provided + features (from categorical for example) to a numerical value that further stages + will be able to use + + Parameters + ---------- + X: SUPPORTED_FEAT_TYPES + A set of features that are going to be validated (type and dimensionality + checks) and a encoder fitted in the case the data needs encoding + """ + if hasattr(X, "iloc") and not scipy.sparse.issparse(X): + X = typing.cast(pd.DataFrame, X) + # Treat a column with all instances a NaN as numerical + # This will prevent doing encoding to a categorical column made completely + # out of nan values -- which will trigger a fail, as encoding is not supported + # with nan values. + # Columns that are completely made of NaN values are provided to the pipeline + # so that later stages decide how to handle them + if np.any(pd.isnull(X)): + for column in X.columns: + if X[column].isna().all(): + X[column] = pd.to_numeric(X[column]) + + self.enc_columns, self.feat_type = self._get_columns_to_encode(X) + + if len(self.enc_columns) > 0: + + self.encoder = make_column_transformer( + (preprocessing.OrdinalEncoder( + handle_unknown='use_encoded_value', + unknown_value=-1, + ), self.enc_columns), + remainder="passthrough" + ) + + # Mypy redefinition + assert self.encoder is not None + self.encoder.fit(X) + + # The column transformer reoders the feature types - we therefore need to change + # it as well + def comparator(cmp1: str, cmp2: str) -> int: + if ( + cmp1 == 'categorical' and cmp2 == 'categorical' + or cmp1 == 'numerical' and cmp2 == 'numerical' + ): + return 0 + elif cmp1 == 'categorical' and cmp2 == 'numerical': + return -1 + elif cmp1 == 'numerical' and cmp2 == 'categorical': + return 1 + else: + raise ValueError((cmp1, cmp2)) + self.feat_type = sorted( + self.feat_type, + key=functools.cmp_to_key(comparator) + ) + return self + + def transform( + self, + X: SUPPORTED_FEAT_TYPES, + ) -> np.ndarray: + """ + Validates and fit a categorical encoder (if needed) to the features. + The supported data types are List, numpy arrays and pandas DataFrames. 
+ + Parameters + ---------- + X_train: SUPPORTED_FEAT_TYPES + A set of features, whose categorical features are going to be + transformed + + Return + ------ + np.ndarray: + The transformed array + """ + if not self._is_fitted: + raise NotFittedError("Cannot call transform on a validator that is not fitted") + + # If a list was provided, it will be converted to pandas + if isinstance(X, list): + X, _ = self.list_to_dataframe(X) + + if hasattr(X, "iloc") and not scipy.sparse.issparse(X): + X = typing.cast(pd.DataFrame, X) + if np.any(pd.isnull(X)): + for column in X.columns: + if X[column].isna().all(): + X[column] = pd.to_numeric(X[column]) + + # Check the data here so we catch problems on new test data + self._check_data(X) + + # Pandas related transformations + if hasattr(X, "iloc") and self.encoder is not None: + if np.any(pd.isnull(X)): + # After above check it means that if there is a NaN + # the whole column must be NaN + # Make sure it is numerical and let the pipeline handle it + for column in X.columns: + if X[column].isna().all(): + X[column] = pd.to_numeric(X[column]) + X = self.encoder.transform(X) + + # Sparse related transformations + # Not all sparse format support index sorting + if scipy.sparse.issparse(X) and hasattr(X, 'sort_indices'): + X.sort_indices() + + return sklearn.utils.check_array( + X, + force_all_finite=False, + accept_sparse='csr' + ) + + def _check_data( + self, + X: SUPPORTED_FEAT_TYPES, + ) -> None: + """ + Feature dimensionality and data type checks + + Parameters + ---------- + X: SUPPORTED_FEAT_TYPES + A set of features that are going to be validated (type and dimensionality + checks) and a encoder fitted in the case the data needs encoding + """ + + if not isinstance(X, (np.ndarray, pd.DataFrame)) and not scipy.sparse.issparse(X): + raise ValueError("Auto-sklearn only supports Numpy arrays, Pandas DataFrames," + " scipy sparse and Python Lists, yet, the provided input is" + " of type {}".format( + type(X) + )) + + if self.data_type is None: + self.data_type = type(X) + if self.data_type != type(X): + self.logger.warning("Auto-sklearn previously received features of type %s " + "yet the current features have type %s. Changing the dtype " + "of inputs to an estimator might cause problems" % ( + str(self.data_type), + str(type(X)), + ), + ) + + # Do not support category/string numpy data. Only numbers + if hasattr(X, "dtype"): + if not np.issubdtype(X.dtype.type, np.number): # type: ignore[union-attr] + raise ValueError( + "When providing a numpy array to Auto-sklearn, the only valid " + "dtypes are numerical ones. The provided data type {} is not supported." + "".format( + X.dtype.type, # type: ignore[union-attr] + ) + ) + + # Then for Pandas, we do not support Nan in categorical columns + if hasattr(X, "iloc"): + # If entered here, we have a pandas dataframe + X = typing.cast(pd.DataFrame, X) + + # Define the column to be encoded here as the feature validator is fitted once + # per estimator + enc_columns, _ = self._get_columns_to_encode(X) + + if len(enc_columns) > 0: + if np.any(pd.isnull( + X[enc_columns].dropna( # type: ignore[call-overload] + axis='columns', how='all') + )): + # Ignore all NaN columns, and if still a NaN + # Error out + raise ValueError("Categorical features in a dataframe cannot contain " + "missing/NaN values. 
The OrdinalEncoder used by " + "Auto-sklearn cannot handle this yet (due to a " + "limitation on scikit-learn being addressed via: " + "https://github.com/scikit-learn/scikit-learn/issues/17123)" + ) + column_order = [column for column in X.columns] + if len(self.column_order) > 0: + if self.column_order != column_order: + raise ValueError("Changing the column order of the features after fit() is " + "not supported. Fit() method was called with " + "{} whereas the new features have {} as type".format( + self.column_order, + column_order, + )) + else: + self.column_order = column_order + dtypes = [dtype.name for dtype in X.dtypes] + if len(self.dtypes) > 0: + if self.dtypes != dtypes: + raise ValueError("Changing the dtype of the features after fit() is " + "not supported. Fit() method was called with " + "{} whereas the new features have {} as type".format( + self.dtypes, + dtypes, + )) + else: + self.dtypes = dtypes + + def _get_columns_to_encode( + self, + X: pd.DataFrame, + ) -> typing.Tuple[typing.List[str], typing.List[str]]: + """ + Return the columns to be encoded from a pandas dataframe + + Parameters + ---------- + X: pd.DataFrame + A set of features that are going to be validated (type and dimensionality + checks) and a encoder fitted in the case the data needs encoding + Returns + ------- + enc_columns: + Columns to encode, if any + feat_type: + Type of each column numerical/categorical + """ + # Register if a column needs encoding + enc_columns = [] + + # Also, register the feature types for the estimator + feat_type = [] + + # Make sure each column is a valid type + for i, column in enumerate(X.columns): + if X[column].dtype.name in ['category', 'bool']: + + enc_columns.append(column) + feat_type.append('categorical') + # Move away from np.issubdtype as it causes + # TypeError: data type not understood in certain pandas types + elif not is_numeric_dtype(X[column]): + if X[column].dtype.name == 'object': + raise ValueError( + "Input Column {} has invalid type object. " + "Cast it to a valid dtype before using it in Auto-Sklearn. " + "Valid types are numerical, categorical or boolean. " + "You can cast it to a valid dtype using " + "pandas.Series.astype ." + "If working with string objects, the following " + "tutorial illustrates how to work with text data: " + "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( # noqa: E501 + column, + ) + ) + elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype( + X[column].dtype + ): + raise ValueError( + "Auto-sklearn does not support time and/or date datatype as given " + "in column {}. Please convert the time information to a numerical value " + "first. One example on how to do this can be found on " + "https://stats.stackexchange.com/questions/311494/".format( + column, + ) + ) + else: + raise ValueError( + "Input Column {} has unsupported dtype {}. " + "Supported column types are categorical/bool/numerical dtypes. " + "Make sure your data is formatted in a correct way, " + "before feeding it to Auto-Sklearn.".format( + column, + X[column].dtype.name, + ) + ) + else: + feat_type.append('numerical') + return enc_columns, feat_type + + def list_to_dataframe( + self, + X_train: SUPPORTED_FEAT_TYPES, + X_test: typing.Optional[SUPPORTED_FEAT_TYPES] = None, + ) -> typing.Tuple[pd.DataFrame, typing.Optional[pd.DataFrame]]: + """ + Converts a list to a pandas DataFrame. In this process, column types are inferred. 
+ + If test data is provided, we proactively match it to train data + + Parameters + ---------- + X_train: SUPPORTED_FEAT_TYPES + A set of features that are going to be validated (type and dimensionality + checks) and a encoder fitted in the case the data needs encoding + X_test: typing.Optional[SUPPORTED_FEAT_TYPES] + A hold out set of data used for checking + Returns + ------- + pd.DataFrame: + transformed train data from list to pandas DataFrame + pd.DataFrame: + transformed test data from list to pandas DataFrame + """ + + # If a list was provided, it will be converted to pandas + X_train = pd.DataFrame(data=X_train).infer_objects() + self.logger.warning("The provided feature types to autosklearn are of type list." + "Features have been interpreted as: {}".format( + [(col, t) for col, t in zip(X_train.columns, X_train.dtypes)] + )) + if X_test is not None: + if not isinstance(X_test, list): + self.logger.warning("Train features are a list while the provided test data" + "is {}. X_test will be casted as DataFrame.".format( + type(X_test) + )) + X_test = pd.DataFrame(data=X_test).infer_objects() + return X_train, X_test diff --git a/autosklearn/data/target_validator.py b/autosklearn/data/target_validator.py new file mode 100644 index 0000000000..05381aabdb --- /dev/null +++ b/autosklearn/data/target_validator.py @@ -0,0 +1,388 @@ +import logging +import typing + +import numpy as np + +import pandas as pd +from pandas.api.types import is_numeric_dtype + +import scipy.sparse + +import sklearn.utils +from sklearn import preprocessing +from sklearn.base import BaseEstimator +from sklearn.exceptions import NotFittedError +from sklearn.utils.multiclass import type_of_target + +from autosklearn.util.logging_ import PickableLoggerAdapter + + +SUPPORTED_TARGET_TYPES = typing.Union[ + typing.List, + pd.Series, + pd.DataFrame, + np.ndarray, + scipy.sparse.bsr_matrix, + scipy.sparse.coo_matrix, + scipy.sparse.csc_matrix, + scipy.sparse.csr_matrix, + scipy.sparse.dia_matrix, + scipy.sparse.dok_matrix, + scipy.sparse.lil_matrix, +] + + +class TargetValidator(BaseEstimator): + """ + A class to pre-process targets. It validates the data provided during fit (to make sure + it matches Sklearn expectation) as well as encoding the targets in case of classification + Attributes + ---------- + is_classification: bool + A bool that indicates if the validator should operate in classification mode. + During classification, the targets are encoded. 
+ encoder: typing.Optional[BaseEstimator] + Host a encoder object if the data requires transformation (for example, + if provided a categorical column in a pandas DataFrame) + enc_columns: typing.List[str] + List of columns that where encoded + """ + def __init__(self, + is_classification: bool = False, + logger: typing.Optional[PickableLoggerAdapter] = None, + ) -> None: + self.is_classification = is_classification + + self.data_type = None # type: typing.Optional[type] + + self.encoder = None # type: typing.Optional[BaseEstimator] + + self.out_dimensionality = None # type: typing.Optional[int] + self.type_of_target = None # type: typing.Optional[str] + + self.logger = logger if logger is not None else logging.getLogger(__name__) + + # Store the dtype for remapping to correct type + self.dtype = None # type: typing.Optional[type] + + self._is_fitted = False + + def fit( + self, + y_train: SUPPORTED_TARGET_TYPES, + y_test: typing.Optional[SUPPORTED_TARGET_TYPES] = None, + ) -> BaseEstimator: + """ + Validates and fit a categorical encoder (if needed) to the targets + The supported data types are List, numpy arrays and pandas DataFrames. + + Parameters + ---------- + y_train: SUPPORTED_TARGET_TYPES + A set of targets set aside for training + y_test: typing.Union[SUPPORTED_TARGET_TYPES] + A hold out set of data used of the targets. It is also used to fit the + categories of the encoder. + """ + # Check that the data is valid + self._check_data(y_train) + + shape = np.shape(y_train) + if y_test is not None: + self._check_data(y_test) + + if len(shape) != len(np.shape(y_test)) or ( + len(shape) > 1 and (shape[1] != np.shape(y_test)[1])): + raise ValueError("The dimensionality of the train and test targets " + "does not match train({}) != test({})".format( + np.shape(y_train), + np.shape(y_test) + )) + if isinstance(y_train, pd.DataFrame): + y_train = typing.cast(pd.DataFrame, y_train) + y_test = typing.cast(pd.DataFrame, y_test) + if y_train.columns.tolist() != y_test.columns.tolist(): + raise ValueError( + "Train and test targets must both have the same columns, yet " + "y={} and y_test={} ".format( + y_train.columns, + y_test.columns + ) + ) + + if list(y_train.dtypes) != list(y_test.dtypes): + raise ValueError("Train and test targets must both have the same dtypes") + + if self.out_dimensionality is None: + self.out_dimensionality = 1 if len(shape) == 1 else shape[1] + else: + _n_outputs = 1 if len(shape) == 1 else shape[1] + if self.out_dimensionality != _n_outputs: + raise ValueError('Number of outputs changed from %d to %d!' % + (self.out_dimensionality, _n_outputs)) + + # Fit on the training data + self._fit(y_train, y_test) + + self._is_fitted = True + + return self + + def _fit( + self, + y_train: SUPPORTED_TARGET_TYPES, + y_test: typing.Optional[SUPPORTED_TARGET_TYPES] = None, + ) -> BaseEstimator: + """ + If dealing with classification, this utility encodes the targets. + + It does so by also using the classes from the test data, to prevent encoding + errors + + Parameters + ---------- + y_train: SUPPORTED_TARGET_TYPES + The labels of the current task. 
They are going to be encoded in case + of classification + y_test: typing.Optional[SUPPORTED_TARGET_TYPES] + A holdout set of labels + """ + if not self.is_classification or self.type_of_target == 'multilabel-indicator': + # Only fit an encoder for classification tasks + # Also, encoding multilabel indicator data makes the data multiclass + # Let the user employ a MultiLabelBinarizer if needed + return self + + if y_test is not None: + if hasattr(y_train, "iloc"): + y_train = pd.concat([y_train, y_test], ignore_index=True, sort=False) + elif isinstance(y_train, list): + y_train = y_train + y_test + elif isinstance(y_train, np.ndarray): + y_train = np.concatenate((y_train, y_test)) + + ndim = len(np.shape(y_train)) + if ndim == 1 or (ndim > 1 and np.shape(y_train)[1] == 1): + # The label encoder makes sure data is, and remains + # 1 dimensional + self.encoder = preprocessing.OrdinalEncoder(handle_unknown='use_encoded_value', + unknown_value=-1) + else: + # We should not reach this if statement as we check for type of targets before + raise ValueError("Multi-dimensional classification is not yet supported. " + "Encoding multidimensional data converts multiple columns " + "to a 1 dimensional encoding. Data involved = {}/{}".format( + np.shape(y_train), + self.type_of_target + )) + + # Mypy redefinition + assert self.encoder is not None + + # remove ravel warning from pandas Series + if ndim > 1: + self.encoder.fit(y_train) + else: + if hasattr(y_train, 'iloc'): + y_train = typing.cast(pd.DataFrame, y_train) + self.encoder.fit(y_train.to_numpy().reshape(-1, 1)) + else: + self.encoder.fit(np.array(y_train).reshape(-1, 1)) + + # we leave objects unchanged, so no need to store dtype in this case + if hasattr(y_train, 'dtype'): + # Series and numpy arrays are checked here + # Cast is as numpy for mypy checks + y_train = typing.cast(np.ndarray, y_train) + if is_numeric_dtype(y_train.dtype): + self.dtype = y_train.dtype + elif hasattr(y_train, 'dtypes') and is_numeric_dtype(typing.cast(pd.DataFrame, + y_train).dtypes[0]): + # This case is for pandas array with a single column + y_train = typing.cast(pd.DataFrame, y_train) + self.dtype = y_train.dtypes[0] + + return self + + def transform( + self, + y: typing.Union[SUPPORTED_TARGET_TYPES], + ) -> np.ndarray: + """ + Validates and fit a categorical encoder (if needed) to the features. + The supported data types are List, numpy arrays and pandas DataFrames. + + Parameters + ---------- + y: SUPPORTED_TARGET_TYPES + A set of targets that are going to be encoded if the current task + is classification + Return + ------ + np.ndarray: + The transformed array + """ + if not self._is_fitted: + raise NotFittedError("Cannot call transform on a validator that is not fitted") + + # Check the data here so we catch problems on new test data + self._check_data(y) + + if self.encoder is not None: + # remove ravel warning from pandas Series + shape = np.shape(y) + if len(shape) > 1: + y = self.encoder.transform(y) + else: + # The Ordinal encoder expects a 2 dimensional input. 
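# (Illustration, not part of the patch: a 1-D target such as
# np.array(['cat', 'dog', 'cat']) goes into the encoder as
# [['cat'], ['dog'], ['cat']] and comes back flattened as [0., 1., 0.]
# after the reshape below.)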
+ # The targets are 1 dimensional, so reshape to match the expected shape + if hasattr(y, 'iloc'): + y = typing.cast(pd.DataFrame, y) + y = self.encoder.transform(y.to_numpy().reshape(-1, 1)).reshape(-1) + else: + y = self.encoder.transform(np.array(y).reshape(-1, 1)).reshape(-1) + + # sklearn check array will make sure we have the + # correct numerical features for the array + # Also, a numpy array will be created + y = sklearn.utils.check_array( + y, + force_all_finite=True, + accept_sparse='csr', + ensure_2d=False, + ) + + # When translating a dataframe to numpy, make sure we + # honor the ravel requirement + if y.ndim == 2 and y.shape[1] == 1: + y = np.ravel(y) + + return y + + def inverse_transform( + self, + y: SUPPORTED_TARGET_TYPES, + ) -> np.ndarray: + """ + Revert any encoding transformation done on a target array + + Parameters + ---------- + y: typing.Union[np.ndarray, pd.DataFrame, pd.Series] + Target array to be transformed back to original form before encoding + Return + ------ + np.ndarray: + The transformed array + """ + if not self._is_fitted: + raise NotFittedError("Cannot call inverse_transform on a validator that is not fitted") + + if self.encoder is None: + return y + shape = np.shape(y) + if len(shape) > 1: + y = self.encoder.inverse_transform(y) + else: + # The targets should be a flattened array, hence reshape with -1 + if hasattr(y, 'iloc'): + y = typing.cast(pd.DataFrame, y) + y = self.encoder.inverse_transform(y.to_numpy().reshape(-1, 1)).reshape(-1) + else: + y = self.encoder.inverse_transform(np.array(y).reshape(-1, 1)).reshape(-1) + + # Inverse transform returns a numpy array of type object + # This breaks certain metrics as accuracy, which makes type_of_target be unknown + # If while fit a dtype was observed, we try to honor that dtype + if self.dtype is not None: + y = y.astype(self.dtype) + return y + + def is_single_column_target(self) -> bool: + """ + Output is encoded with a single column encoding + """ + return self.out_dimensionality == 1 + + def _check_data( + self, + y: SUPPORTED_TARGET_TYPES, + ) -> None: + """ + Perform dimensionality and data type checks on the targets + + Parameters + ---------- + y: typing.Union[np.ndarray, pd.DataFrame, pd.Series] + A set of features whose dimensionality and data type is going to be checked + """ + + if not isinstance( + y, (np.ndarray, pd.DataFrame, list, pd.Series)) and not scipy.sparse.issparse(y): + raise ValueError("Auto-sklearn only supports Numpy arrays, Pandas DataFrames," + " pd.Series, sparse data and Python Lists as targets, yet, " + "the provided input is of type {}".format( + type(y) + )) + + # Sparse data muss be numerical + # Type ignore on attribute because sparse targets have a dtype + if scipy.sparse.issparse(y) and not np.issubdtype(y.dtype.type, # type: ignore[union-attr] + np.number): + raise ValueError("When providing a sparse matrix as targets, the only supported " + "values are numerical. Please consider using a dense" + " instead." + ) + + if self.data_type is None: + self.data_type = type(y) + if self.data_type != type(y): + self.logger.warning("Auto-sklearn previously received targets of type %s " + "yet the current features have type %s. 
Changing the dtype " + "of inputs to an estimator might cause problems" % ( + str(self.data_type), + str(type(y)), + ), + ) + + # No Nan is supported + has_nan_values = False + if hasattr(y, 'iloc'): + has_nan_values = typing.cast(pd.DataFrame, y).isnull().values.any() + if scipy.sparse.issparse(y): + y = typing.cast(scipy.sparse.spmatrix, y) + has_nan_values = not np.array_equal(y.data, y.data) + else: + # List and array like values are considered here + # np.isnan cannot work on strings, so we have to check for every element + # but NaN, are not equal to themselves: + has_nan_values = not np.array_equal(y, y) + if has_nan_values: + raise ValueError("Target values cannot contain missing/NaN values. " + "This is not supported by scikit-learn. " + ) + + # Pandas Series is not supported for multi-label indicator + # This format checks are done by type of target + try: + self.type_of_target = type_of_target(y) + except Exception as e: + raise ValueError("The provided data could not be interpreted by Sklearn. " + "While determining the type of the targets via type_of_target " + "run into exception: {}.".format(e)) + + supported_output_types = ('binary', + 'continuous', + 'continuous-multioutput', + 'multiclass', + 'multilabel-indicator', + # Notice unknown/multiclass-multioutput are not supported + # This can only happen during testing only as estimators + # should filter out unsupported types. + ) + if self.type_of_target not in supported_output_types: + raise ValueError("Provided targets are not supported by Auto-Sklearn. " + "Provided type is {} whereas supported types are {}.".format( + self.type_of_target, + supported_output_types + )) diff --git a/autosklearn/data/validation.py b/autosklearn/data/validation.py index 4c037624b3..f269c5ef8f 100644 --- a/autosklearn/data/validation.py +++ b/autosklearn/data/validation.py @@ -1,22 +1,18 @@ # -*- encoding: utf-8 -*- - -import functools -import warnings -from typing import List, Optional, Tuple, Union +import logging +import typing import numpy as np -import pandas as pd -from pandas.api.types import is_numeric_dtype - -import scipy.sparse +from sklearn.base import BaseEstimator +from sklearn.exceptions import NotFittedError -import sklearn.utils -from sklearn import preprocessing -from sklearn.compose import make_column_transformer +from autosklearn.data.feature_validator import FeatureValidator, SUPPORTED_FEAT_TYPES +from autosklearn.data.target_validator import SUPPORTED_TARGET_TYPES, TargetValidator +from autosklearn.util.logging_ import get_named_client_logger -class InputValidator: +class InputValidator(BaseEstimator): """ Makes sure the input data complies with Auto-sklearn requirements. Categorical inputs are encoded via a Label Encoder, if the input @@ -24,605 +20,128 @@ class InputValidator: This class also perform checks for data integrity and flags the user via informative errors. 
+ Attributes + ---------- + feat_type: typing.Optional[typing.List[str]] + In case the data is not a pandas DataFrame, this list indicates + which columns should be treated as categorical + is_classification: bool + For classification task, this flag indicates that the target data + should be encoded + feature_validator: FeatureValidator + A FeatureValidator instance used to validate and encode feature columns to match + sklearn expectations on the data + target_validator: TargetValidator + A TargetValidator instance used to validate and encode (in case of classification) + the target values """ - def __init__(self) -> None: - self.valid_pd_enc_dtypes = ['category', 'bool'] - - # If a dataframe was provided, we populate - # this attribute with the column types from the dataframe - # That is, this attribute contains whether autosklearn - # should treat a column as categorical or numerical - # During fit, if the user provided feature_types, the user - # constrain is honored. If not, this attribute is used. - self.feature_types = None # type: Optional[List[str]] - - # Whereas autosklearn performed encoding on the dataframe - # We need the target encoder as a decoder mechanism - self.feature_encoder = None - self.target_encoder = None - self.enc_columns = [] # type: List[int] - - # During consecutive calls to the validator, - # track the number of outputs of the targets - # We need to make sure y_train/y_test have the - # same dimensionality - self._n_outputs = None - - # Add support to make sure that the input to - # autosklearn has consistent dtype through calls. - # That is, once fitted, changes in the input dtype - # are not allowed - self.features_type = None # type: Optional[type] - self.target_type = None # type: Optional[type] - - def register_user_feat_type(self, feat_type: Optional[List[str]], - X: Union[pd.DataFrame, np.ndarray]) -> None: - """ - Incorporate information of the feature types when processing a Numpy array. - In case feature types is provided, if using a pd.DataFrame, this utility errors - out, explaining to the user this is contradictory. - """ - if hasattr(X, "iloc") and feat_type is not None: - raise ValueError("When providing a DataFrame to Auto-Sklearn, we extract " - "the feature types from the DataFrame.dtypes. That is, " - "providing the option feat_type to the fit method is not " - "supported when using a Dataframe. Please make sure that the " - "type of each column in your DataFrame is properly set. " - "More details about having the correct data type in your " - "DataFrame can be seen in " - "https://pandas.pydata.org/pandas-docs/stable/reference" - "/api/pandas.DataFrame.astype.html") - elif feat_type is None: - # Nothing to register. No feat type is provided - # or the features are not numpy/list where this is required - return - - # Some checks if feat_type is provided - if len(feat_type) != X.shape[1]: - raise ValueError('Array feat_type does not have same number of ' - 'variables as X has features. %d vs %d.' 
% - (len(feat_type), X.shape[1])) - if not all([isinstance(f, str) for f in feat_type]): - raise ValueError('Array feat_type must only contain strings.') - - for ft in feat_type: - if ft.lower() not in ['categorical', 'numerical']: - raise ValueError('Only `Categorical` and `Numerical` are ' - 'valid feature types, you passed `%s`' % ft) - - # Here we register proactively the feature types for - # Processing Numpy arrays - self.feature_types = feat_type - - def validate( + def __init__( self, - X: Union[pd.DataFrame, np.ndarray], - y: Union[pd.DataFrame, np.ndarray], + feat_type: typing.Optional[typing.List[str]] = None, is_classification: bool = False, - ) -> Tuple[np.ndarray, np.ndarray]: - """ - Wrapper for feature/targets validation - - Makes sure consistent number of samples within target and - features. - """ - - X = self.validate_features(X) - y = self.validate_target(y, is_classification) - - if X.shape[0] != y.shape[0]: - raise ValueError( - "The number of samples from the features X={} should match " - "the number of samples from the target y={}".format( - X.shape[0], - y.shape[0] - ) - ) - return X, y - - def validate_features( - self, - X: Union[pd.DataFrame, np.ndarray], - ) -> np.ndarray: - """ - Wrapper around sklearn check_array. Translates a pandas - Dataframe to a valid input for sklearn. - """ - - # Make sure that once fitted, we don't allow new dtypes - if self.features_type is None: - self.features_type = type(X) - if self.features_type != type(X): - raise ValueError("Auto-sklearn previously received features of type {} " - "yet the current features have type {}. Changing the dtype " - "of inputs to an estimator is not supported.".format( - self.features_type, - type(X) - ) - ) - - # Do not support category/string numpy data. Only numbers - if hasattr(X, "dtype") and not np.issubdtype(X.dtype.type, np.number): - raise ValueError( - "When providing a numpy array to Auto-sklearn, the only valid " - "dtypes are numerical ones. The provided data type {} is not supported." - "".format( - X.dtype.type, - ) - ) - - # Pre-process dataframe to make them numerical - # Also, encode numpy categorical objects - if hasattr(X, "iloc") and not scipy.sparse.issparse(X): - # Pandas validation provide extra user information - X = self._check_and_encode_features(X) - - if scipy.sparse.issparse(X): - X.sort_indices() - - # sklearn check array will make sure we have the - # correct numerical features for the array - # Also, a numpy array will be created - X = sklearn.utils.check_array( - X, - force_all_finite=False, - accept_sparse='csr' - ) - return X - - def validate_target( - self, - y: Union[pd.DataFrame, np.ndarray], - is_classification: bool = False, - ) -> np.ndarray: - """ - Wrapper around sklearn check_array. Translates a pandas - Dataframe to a valid input for sklearn. - """ - - # Make sure that once fitted, we don't allow new dtypes - if self.target_type is None: - self.target_type = type(y) - if self.target_type != type(y): - raise ValueError("Auto-sklearn previously received targets of type {} " - "yet the current target has type {}. Changing the dtype " - "of inputs to an estimator is not supported.".format( - self.target_type, - type(y) - ) - ) - - # Target data as sparse is not supported - if scipy.sparse.issparse(y): - raise ValueError("Unsupported target data provided" - "Input targets to auto-sklearn must not be of " - "type sparse. Please convert the target input (y) " - "to a dense array via scipy.sparse.csr_matrix.todense(). 
" - ) - - # No Nan is supported - if np.any(pd.isnull(y)): - raise ValueError("Target values cannot contain missing/NaN values. " - "This is not supported by scikit-learn. " - ) - - if not hasattr(y, "iloc"): - y = np.atleast_1d(y) - if y.ndim == 2 and y.shape[1] == 1: - warnings.warn("A column-vector y was passed when a 1d array was" - " expected. Will change shape via np.ravel().", - sklearn.utils.DataConversionWarning, stacklevel=2) - y = np.ravel(y) - - # During classification, we do ordinal encoding - # We train a common model for test and train - # If an encoder was ever done for an estimator, - # use it always - # For regression, we default to the check_array in sklearn - # learn. This handles numerical checking and object conversion - # For regression, we expect the user to provide numerical input - # Next check will catch that - if is_classification or self.target_encoder is not None: - y = self._check_and_encode_target(y) - - # In code check to make sure everything is numeric - if hasattr(y, "iloc"): - is_number = np.vectorize(lambda x: pd.api.types.is_numeric_dtype(x)) - if not np.all(is_number(y.dtypes)): - raise ValueError( - "During the target validation (y_train/y_test) an invalid" - " input was detected. " - "Input dataframe to autosklearn must only contain numerical" - " dtypes, yet it has: {} dtypes.".format( - y.dtypes - ) - ) - elif not np.issubdtype(y.dtype, np.number): - raise ValueError( - "During the target validation (y_train/y_test) an invalid" - " input was detected. " - "Input to autosklearn must have a numerical dtype, yet it is: {}".format( - y.dtype - ) - ) - - # sklearn check array will make sure we have the - # correct numerical features for the array - # Also, a numpy array will be created - y = sklearn.utils.check_array( - y, - force_all_finite=True, - accept_sparse='csr', - ensure_2d=False, - ) - - # When translating a dataframe to numpy, make sure we - # honor the ravel requirement - if y.ndim == 2 and y.shape[1] == 1: - y = np.ravel(y) - - if self._n_outputs is None: - self._n_outputs = 1 if len(y.shape) == 1 else y.shape[1] - else: - _n_outputs = 1 if len(y.shape) == 1 else y.shape[1] - if self._n_outputs != _n_outputs: - raise ValueError('Number of outputs changed from %d to %d!' % - (self._n_outputs, _n_outputs)) - - return y - - def is_single_column_target(self) -> bool: - """ - Output is encoded with a single column encoding - """ - return self._n_outputs == 1 - - def _check_and_get_columns_to_encode( - self, - X: pd.DataFrame, - ) -> Tuple[List[int], List[str]]: - # Register if a column needs encoding - enc_columns = [] - - # Also, register the feature types for the estimator - feature_types = [] - - # Make sure each column is a valid type - for i, column in enumerate(X.columns): - if X[column].dtype.name in self.valid_pd_enc_dtypes: - - if hasattr(X, "iloc"): - enc_columns.append(column) - else: - enc_columns.append(i) - feature_types.append('categorical') - # Move away from np.issubdtype as it causes - # TypeError: data type not understood in certain pandas types - elif not is_numeric_dtype(X[column]): - if X[column].dtype.name == 'object': - raise ValueError( - "Input Column {} has invalid type object. " - "Cast it to a valid dtype before using it in Auto-Sklearn. " - "Valid types are numerical, categorical or boolean. " - "You can cast it to a valid dtype using " - "pandas.Series.astype. 
" - "If working with string objects, the following " - "tutorial illustrates how to work with text data: " - "https://scikit-learn.org/stable/tutorial/text_analytics/working_with_text_data.html".format( # noqa: E501 - column, - ) - ) - elif pd.core.dtypes.common.is_datetime_or_timedelta_dtype( - X[column].dtype - ): - raise ValueError( - "Auto-sklearn does not support time and/or date datatype as given " - "in column {}. Please convert the time information to a numerical value " - "first. One example on how to do this can be found on " - "https://stats.stackexchange.com/questions/311494/".format( - column, - ) - ) - else: - raise ValueError( - "Input Column {} has unsupported dtype {}. " - "Supported column types are categorical/bool/numerical dtypes. " - "Make sure your data is formatted in a correct way, " - "before feeding it to Auto-Sklearn.".format( - column, - X[column].dtype.name, - ) - ) - else: - feature_types.append('numerical') - return enc_columns, feature_types - - def _check_and_encode_features( - self, - X: pd.DataFrame, - ) -> Union[pd.DataFrame, np.ndarray]: - """ - Interprets a Pandas - Uses .iloc as a safe way to deal with pandas object - """ - # Start with the features - enc_columns, feature_types = self._check_and_get_columns_to_encode(X) - - # If there is a Nan, we cannot encode it due to a scikit learn limitation - if len(enc_columns) > 0: - if np.any(pd.isnull(X[enc_columns].dropna(axis='columns', how='all'))): - # Ignore all NaN columns, and if still a NaN - # Error out - raise ValueError("Categorical features in a dataframe cannot contain " - "missing/NaN values. The OrdinalEncoder used by " - "Auto-sklearn cannot handle this yet (due to a " - "limitation on scikit-learn being addressed via: " - "https://github.com/scikit-learn/scikit-learn/issues/17123)" - ) - elif np.any(pd.isnull(X)): - # After above check it means that if there is a NaN - # the whole column must be NaN - # Make sure it is numerical and let the pipeline handle it - for column in X.columns: - if X[column].isna().all(): - X[column] = pd.to_numeric(X[column]) - - # Make sure we only set this once. It should not change - if not self.feature_types: - self.feature_types = feature_types - - # This proc has to handle multiple calls, for X_train - # and X_test scenarios. We have to make sure also that - # data is consistent within calls - if enc_columns: - if self.enc_columns and self.enc_columns != enc_columns: - raise ValueError( - "Changing the column-types of the input data to Auto-Sklearn is not " - "allowed. The estimator previously was fitted with categorical/boolean " - "columns {}, yet, the new input data has categorical/boolean values {}. " - "Please recreate the estimator from scratch when changing the input " - "data. 
".format( - self.enc_columns, - enc_columns, - ) - ) - else: - self.enc_columns = enc_columns - - if not self.feature_encoder: - self.feature_encoder = make_column_transformer( - (preprocessing.OrdinalEncoder(), self.enc_columns), - remainder="passthrough" - ) - - # Mypy redefinition - assert self.feature_encoder is not None - self.feature_encoder.fit(X) - - # The column transformer reoders the feature types - we therefore need to change - # it as well - def comparator(cmp1, cmp2): - if ( - cmp1 == 'categorical' and cmp2 == 'categorical' - or cmp1 == 'numerical' and cmp2 == 'numerical' - ): - return 0 - elif cmp1 == 'categorical' and cmp2 == 'numerical': - return -1 - elif cmp1 == 'numerical' and cmp2 == 'categorical': - return 1 - else: - raise ValueError((cmp1, cmp2)) - self.feature_types = sorted( - self.feature_types, - key=functools.cmp_to_key(comparator) - ) - - if self.feature_encoder: - try: - X = self.feature_encoder.transform(X) - except ValueError as e: - if 'Found unknown categories' in e.args[0]: - # Make the message more informative - raise ValueError( - "During fit, the input features contained categorical values in columns" - "{}, with categories {} which were encoded by Auto-sklearn automatically." - "Nevertheless, a new input contained new categories not seen during " - "training = {}. The OrdinalEncoder used by Auto-sklearn cannot handle " - "this yet (due to a limitation on scikit-learn being addressed via:" - " https://github.com/scikit-learn/scikit-learn/issues/17123)" - "".format( - self.enc_columns, - self.feature_encoder.transformers_[0][1].categories_, - e.args[0], - ) - ) - else: - raise e - - # In code check to make sure everything is numeric - if hasattr(X, "iloc"): - is_number = np.vectorize(lambda x: pd.api.types.is_numeric_dtype(x)) - if not np.all(is_number(X.dtypes)): - raise ValueError( - "Failed to convert the input dataframe to numerical dtypes: {}".format( - X.dtypes - ) - ) - elif not np.issubdtype(X.dtype, np.number): - raise ValueError( - "Failed to convert the input array to numerical dtype: {}".format( - X.dtype - ) + logger_port: typing.Optional[int] = None, + ) -> None: + self.feat_type = feat_type + self.is_classification = is_classification + self.logger_port = logger_port + if self.logger_port is not None: + self.logger = get_named_client_logger( + name='Validation', + port=self.logger_port, ) - - return X - - def _check_and_encode_target( - self, - y: Union[pd.DataFrame, np.ndarray], - ) -> Union[pd.DataFrame, np.ndarray]: - """ - This method encodes - categorical series to a numerical equivalent. 
- - An ordinal encoder is used for the translation - - """ - - # Convert pd.Series to dataframe as categorical series - # lack many useful methods - if isinstance(y, pd.Series): - y = y.to_frame().reset_index(drop=True) - - if hasattr(y, "iloc"): - self._check_and_get_columns_to_encode(y) - - if not self.target_encoder: - if y.ndim == 1 or (y.ndim > 1 and y.shape[1] == 1): - # The label encoder makes sure data is, and remains - # 1 dimensional - self.target_encoder = preprocessing.LabelEncoder() - else: - self.target_encoder = make_column_transformer( - (preprocessing.OrdinalEncoder(), list(range(y.shape[1]))), - ) - - # Mypy redefinition - assert self.target_encoder is not None - - # remove ravel warning from pandas Series - if len(y.shape) > 1 and y.shape[1] == 1 and hasattr(y, "to_numpy"): - self.target_encoder.fit(y.to_numpy().ravel()) - else: - self.target_encoder.fit(y) - - try: - # remove ravel warning from pandas Series - if len(y.shape) > 1 and y.shape[1] == 1 and hasattr(y, "to_numpy"): - y = self.target_encoder.transform(y.to_numpy().ravel()) - else: - y = self.target_encoder.transform(y) - except ValueError as e: - if 'Found unknown categories' in e.args[0]: - # Make the message more informative for Ordinal - raise ValueError( - "During fit, the target array contained the categorical values {} " - "which were encoded by Auto-sklearn automatically. " - "Nevertheless, a new target set contained new categories not seen during " - "training = {}. The OrdinalEncoder used by Auto-sklearn cannot handle " - "this yet (due to a limitation on scikit-learn being addressed via:" - " https://github.com/scikit-learn/scikit-learn/issues/17123)" - "".format( - self.target_encoder.transformers_[0][1].categories_, - e.args[0], - ) - ) - elif 'contains previously unseen labels' in e.args[0]: - # Make the message more informative - raise ValueError( - "During fit, the target array contained the categorical values {} " - "which were encoded by Auto-sklearn automatically. " - "Nevertheless, a new target set contained new categories not seen during " - "training = {}. 
This is a limitation in scikit-learn encoders being " - "discussed in //github.com/scikit-learn/scikit-learn/issues/17123".format( - self.target_encoder.classes_, - e.args[0], - ) - ) - else: - raise e - - return y - - def encode_target( - self, - y: np.ndarray, - ) -> np.ndarray: - """ - Encodes the target if there is any encoder - """ - if self.target_encoder is None: - return y else: + self.logger = logging.getLogger('Validation') - return self.target_encoder.transform(y) + self.feature_validator = FeatureValidator(feat_type=self.feat_type, + logger=self.logger) + self.target_validator = TargetValidator(is_classification=self.is_classification, + logger=self.logger) + self._is_fitted = False - def decode_target( + def fit( self, - y: np.ndarray, - ) -> np.ndarray: - """ - If the original target features were encoded, - this method employs the inverse transform method of the encoder - to decode the original features - """ - if self.target_encoder is None: - return y - - # Handle different ndim encoder for target - if hasattr(self.target_encoder, 'inverse_transform'): - return self.target_encoder.inverse_transform(y) - else: - return self.target_encoder.named_transformers_['ordinalencoder'].inverse_transform(y) - - def join_and_check( + X_train: SUPPORTED_FEAT_TYPES, + y_train: SUPPORTED_TARGET_TYPES, + X_test: typing.Optional[SUPPORTED_FEAT_TYPES] = None, + y_test: typing.Optional[SUPPORTED_TARGET_TYPES] = None, + ) -> BaseEstimator: + """ + Validates and fit a categorical encoder (if needed) to the features, and + a encoder for targets in the case of classification. Specifically: + + For features: + + Valid data types are enforced (List, np.ndarray, pd.DataFrame, pd.Series, scipy + sparse) as well as dimensionality checks + + If the provided data is a pandas DataFrame with categorical/boolean/int columns, + such columns will be encoded using an Ordinal Encoder + For targets: + + Checks for dimensionality as well as missing values are performed. + + If performing a classification task, the data is going to be encoded + + Parameters + ---------- + X_train: SUPPORTED_FEAT_TYPES + A set of features that are going to be validated (type and dimensionality + checks). If this data contains categorical columns, an encoder is going to + be instantiated and trained with this data. + y_train: SUPPORTED_TARGET_TYPES + A set of targets that are going to be encoded if the task is for classification + X_test: typing.Optional[SUPPORTED_FEAT_TYPES] + A hold out set of features used for checking + y_test: SUPPORTED_TARGET_TYPES + A hold out set of targets used for checking. Additionally, if the current task + is a classification task, this y_test categories are also going to be used to + fit a pre-processing encoding (to prevent errors on unseen classes). 
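+            In other words, the target categories are learned from y_train and
+            y_test jointly, so a label that only occurs in the test split can
+            still be transformed without an unseen-category error.
+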
+ Returns + ------- + self + """ + # Check that the data is valid + if np.shape(X_train)[0] != np.shape(y_train)[0]: + raise ValueError("Inconsistent number of train datapoints for features and targets," + " {} for features and {} for targets".format( + np.shape(X_train)[0], + np.shape(y_train)[0], + )) + if X_test is not None and np.shape(X_test)[0] != np.shape(y_test)[0]: + raise ValueError("Inconsistent number of test datapoints for features and targets," + " {} for features and {} for targets".format( + np.shape(X_test)[0], + np.shape(y_test)[0], + )) + + self.feature_validator.fit(X_train, X_test) + self.target_validator.fit(y_train, y_test) + self._is_fitted = True + + return self + + def transform( self, - y: Union[np.ndarray, pd.DataFrame], - y_test: Union[np.ndarray, pd.DataFrame] - ) -> Union[np.ndarray, pd.DataFrame]: - """ - This method checks for basic input quality before - merging the inputs to Auto-sklearn for a common - encoding if needed - """ - - # We expect same type of object - if type(y) != type(y_test): - raise ValueError( - "Train and test targets must be of the same type, yet y={} and y_test={}" - "".format( - type(y), - type(y_test) - ) - ) - - if isinstance(y, pd.DataFrame): - # The have to have the same columns and types - if y.columns != y_test.columns: - raise ValueError( - "Train and test targets must both have the same columns, yet " - "y={} and y_test={} ".format( - type(y), - type(y_test) - ) - ) - - if list(y.dtypes) != list(y_test.dtypes): - raise ValueError("Train and test targets must both have the same dtypes") - - return pd.concat([y, y_test], ignore_index=True, sort=False) - elif isinstance(y, np.ndarray): - # The have to have the same columns and types - if len(y.shape) != len(y_test.shape) \ - or (len(y.shape) > 1 and (y.shape[1] != y_test.shape[1])): - raise ValueError("Train and test targets must have the same dimensionality") - - if y.dtype != y_test.dtype: - raise ValueError("Train and test targets must both have the same dtype") - - return np.concatenate((y, y_test)) - elif isinstance(y, list): - # Provide flexibility in the list. When transformed to np.ndarray - # further checks are performed downstream - return y + y_test - elif scipy.sparse.issparse(y): - # Here just return y, vstack from scipy cause ufunc 'isnan' type errors - # in multilabel sparse matrices. Since we don't encode scipy matrices, - # No functionality impact. - return y + X: SUPPORTED_FEAT_TYPES, + y: typing.Optional[SUPPORTED_TARGET_TYPES] = None, + ) -> typing.Tuple[np.ndarray, typing.Optional[np.ndarray]]: + """ + Transform the given target or features to a numpy array + + Parameters + ---------- + X: SUPPORTED_FEAT_TYPES + A set of features to transform + y: typing.Optional[SUPPORTED_TARGET_TYPES] + A set of targets to transform + + Return + ------ + np.ndarray: + The transformed features array + np.ndarray: + The transformed targets array + """ + if not self._is_fitted: + raise NotFittedError("Cannot call transform on a validator that is not fitted") + X_transformed = self.feature_validator.transform(X) + if y is not None: + return X_transformed, self.target_validator.transform(y) else: - raise ValueError("Unsupported input type y={type(y)}. Auto-Sklearn supports " - "Pandas DataFrames, numpy arrays, scipy csr or python lists. " - "Kindly cast your targets to a supported type." 
- ) + return X_transformed, y diff --git a/autosklearn/util/hash.py b/autosklearn/util/hash.py deleted file mode 100644 index d65abfdf18..0000000000 --- a/autosklearn/util/hash.py +++ /dev/null @@ -1,29 +0,0 @@ -import hashlib - -import numpy as np - -import scipy.sparse - - -def hash_array_or_matrix(X: np.ndarray) -> str: - m = hashlib.md5() - - if hasattr(X, "iloc"): - X = X.to_numpy() - - if scipy.sparse.issparse(X): - m.update(X.indices) - m.update(X.indptr) - m.update(X.data) - m.update(str(X.shape).encode('utf8')) - else: - if X.flags['C_CONTIGUOUS']: - m.update(X.data) - m.update(str(X.shape).encode('utf8')) - else: - X_tmp = np.ascontiguousarray(X.T) - m.update(X_tmp.data) - m.update(str(X_tmp.shape).encode('utf8')) - - hash = m.hexdigest() - return hash diff --git a/test/test_automl/test_automl.py b/test/test_automl/test_automl.py index 83a5c90645..023954a42d 100644 --- a/test/test_automl/test_automl.py +++ b/test/test_automl/test_automl.py @@ -16,6 +16,7 @@ from smac.facade.roar_facade import ROAR from autosklearn.automl import AutoML +from autosklearn.data.validation import InputValidator import autosklearn.automl from autosklearn.data.xy_data_manager import XYDataManager from autosklearn.metrics import accuracy, log_loss, balanced_accuracy @@ -123,6 +124,7 @@ def test_refit_shuffle_on_fail(backend, dask_client): ensemble_mock = unittest.mock.Mock() ensemble_mock.get_selected_model_identifiers.return_value = [(1, 1, 50.0)] auto.ensemble_ = ensemble_mock + auto.InputValidator = InputValidator() for budget_type in [None, 'iterations']: auto._budget_type = budget_type @@ -131,6 +133,7 @@ def test_refit_shuffle_on_fail(backend, dask_client): # Make sure a valid 2D array is given to automl X = np.array([1, 2, 3]).reshape(-1, 1) y = np.array([1, 2, 3]) + auto.InputValidator.fit(X, y) auto.refit(X, y) assert failing_model.fit.call_count == 3 @@ -556,17 +559,35 @@ class MyException(Exception): ) # make sure that the logfile was created - import shutil - shutil.copytree(backend.temporary_directory, '/tmp/trydebug') logger_name = 'AutoML(%d):%s' % (1, dataset_name) + logger = logging.getLogger(logger_name) logfile = os.path.join(backend.temporary_directory, logger_name + '.log') - assert os.path.exists(logfile), automl._clean_logger() - with open(logfile) as f: - assert message in f.read(), automl._clean_logger() + assert os.path.exists(logfile), print_debug_information(automl) + str(automl._clean_logger()) + + # Give some time for the error message to be printed in the + # log file + found_message = False + for incr_tolerance in range(5): + with open(logfile) as f: + lines = f.readlines() + if any(message in line for line in lines): + found_message = True + break + else: + time.sleep(incr_tolerance) # Speed up the closing after forced crash automl._clean_logger() + if not found_message: + pytest.fail("Did not find {} in the log file {} for logger {}/{}/{}".format( + message, + print_debug_information(automl), + vars(automl._logger.logger), + vars(logger), + vars(logging.getLogger()) + )) + @pytest.mark.parametrize("metric", [log_loss, balanced_accuracy]) def test_load_best_individual_model(metric, backend, dask_client): @@ -637,31 +658,6 @@ def test_fail_if_feat_type_on_pandas_input(backend, dask_client): ) -def test_fail_if_dtype_changes_automl(backend, dask_client): - """We do not support changes in the input type. 
- Once a estimator is fitted, it should not change data type - """ - automl = autosklearn.automl.AutoML( - backend=backend, - time_left_for_this_task=30, - per_run_time_limit=5, - metric=accuracy, - dask_client=dask_client, - ) - - X_train = pd.DataFrame({'a': [1, 1], 'c': [1, 2]}) - y_train = [1, 0] - automl.InputValidator.validate(X_train, y_train, is_classification=True) - with pytest.raises( - ValueError, - match="Auto-sklearn previously received features of type" - ): - automl.fit( - X_train.to_numpy(), y_train, - task=BINARY_CLASSIFICATION, - ) - - @pytest.mark.parametrize( 'memory_limit,task', [ diff --git a/test/test_automl/test_estimators.py b/test/test_automl/test_estimators.py index 04447e3b17..272e85e27c 100644 --- a/test/test_automl/test_estimators.py +++ b/test/test_automl/test_estimators.py @@ -16,6 +16,7 @@ import sklearn.dummy import sklearn.datasets +from autosklearn.data.validation import InputValidator import autosklearn.pipeline.util as putil from autosklearn.ensemble_builder import MODEL_FN_RE import autosklearn.estimators # noqa F401 @@ -303,10 +304,11 @@ def test_multiclass_prediction(predict_mock, backend, dask_client): backend=backend, dask_client=dask_client, ) - classifier.InputValidator.validate_target( + classifier.InputValidator = InputValidator(is_classification=True) + classifier.InputValidator.target_validator.fit( pd.DataFrame(expected_result, dtype='category'), - is_classification=True, ) + classifier.InputValidator._is_fitted = True actual_result = classifier.predict([None] * len(predicted_indexes)) @@ -321,7 +323,6 @@ def test_multilabel_prediction(predict_mock, backend, dask_client): [0.99, 0.99], [0.99, 0.99]] predicted_indexes = np.array([[1, 0], [1, 0], [0, 1], [1, 1], [1, 1]]) - expected_result = np.array([[2, 13], [2, 13], [1, 17], [2, 17], [2, 17]]) predict_mock.return_value = np.array(predicted_probabilities) @@ -331,14 +332,17 @@ def test_multilabel_prediction(predict_mock, backend, dask_client): backend=backend, dask_client=dask_client, ) - classifier.InputValidator.validate_target( - pd.DataFrame(expected_result, dtype='int64'), - is_classification=True, + classifier.InputValidator = InputValidator(is_classification=True) + classifier.InputValidator.target_validator.fit( + pd.DataFrame(predicted_indexes, dtype='int64'), ) + classifier.InputValidator._is_fitted = True + + assert classifier.InputValidator.target_validator.type_of_target == 'multilabel-indicator' actual_result = classifier.predict([None] * len(predicted_indexes)) - np.testing.assert_array_equal(expected_result, actual_result) + np.testing.assert_array_equal(predicted_indexes, actual_result) def test_can_pickle_classifier(tmp_dir, output_dir, dask_client): @@ -467,8 +471,7 @@ def test_classification_pandas_support(tmp_dir, output_dir, dask_client): # Make sure that at least better than random. # accuracy in sklearn needs valid data # It should be 0.555 as the dataset is unbalanced. 
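+    # predict() now maps predictions back to the original label space, so the
+    # manual encode_target round-trip is no longer needed before scoring.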
- y = automl.automl_.InputValidator.encode_target(y) - prediction = automl.automl_.InputValidator.encode_target(automl.predict(X)) + prediction = automl.predict(X) assert accuracy(y, prediction) > 0.555 assert count_succeses(automl.cv_results_) > 0 diff --git a/test/test_data/test_feature_validator.py b/test/test_data/test_feature_validator.py new file mode 100644 index 0000000000..4fa5bb80de --- /dev/null +++ b/test/test_data/test_feature_validator.py @@ -0,0 +1,547 @@ +import copy +import random + +import numpy as np + +import pandas as pd + +import pytest + +from scipy import sparse + +import sklearn.datasets +import sklearn.model_selection + +from autosklearn.data.feature_validator import FeatureValidator + + +# Fixtures to be used in this class. By default all elements have 100 datapoints +@pytest.fixture +def input_data_featuretest(request): + if request.param == 'numpy_categoricalonly_nonan': + return np.random.randint(10, size=(100, 10)) + elif request.param == 'numpy_numericalonly_nonan': + return np.random.uniform(10, size=(100, 10)) + elif request.param == 'numpy_mixed_nonan': + return np.column_stack([ + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 3)), + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 1)), + ]) + elif request.param == 'numpy_string_nonan': + return np.array([ + ['a', 'b', 'c', 'a', 'b', 'c'], + ['a', 'b', 'd', 'r', 'b', 'c'], + ]) + elif request.param == 'numpy_categoricalonly_nan': + array = np.random.randint(10, size=(100, 10)) + array[50, 0:5] = np.nan + return array + elif request.param == 'numpy_numericalonly_nan': + array = np.random.uniform(10, size=(100, 10)) + array[50, 0:5] = np.nan + return array + elif request.param == 'numpy_mixed_nan': + array = np.column_stack([ + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 3)), + np.random.uniform(10, size=(100, 3)), + np.random.randint(10, size=(100, 1)), + ]) + array[50, 0:5] = np.nan + return array + elif request.param == 'numpy_string_nan': + return np.array([ + ['a', 'b', 'c', 'a', 'b', 'c'], + [np.nan, 'b', 'd', 'r', 'b', 'c'], + ]) + elif request.param == 'pandas_categoricalonly_nonan': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='category') + elif request.param == 'pandas_numericalonly_nonan': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='float') + elif request.param == 'pandas_mixed_nonan': + frame = pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='category') + frame['B'] = pd.to_numeric(frame['B']) + return frame + elif request.param == 'pandas_categoricalonly_nan': + return pd.DataFrame([ + {'A': 1, 'B': 2, 'C': np.nan}, + {'A': 3, 'C': np.nan}, + ], dtype='category') + elif request.param == 'pandas_numericalonly_nan': + return pd.DataFrame([ + {'A': 1, 'B': 2, 'C': np.nan}, + {'A': 3, 'C': np.nan}, + ], dtype='float') + elif request.param == 'pandas_mixed_nan': + frame = pd.DataFrame([ + {'A': 1, 'B': 2, 'C': 8}, + {'A': 3, 'B': 4}, + ], dtype='category') + frame['B'] = pd.to_numeric(frame['B']) + return frame + elif request.param == 'pandas_string_nonan': + return pd.DataFrame([ + {'A': 1, 'B': 2}, + {'A': 3, 'B': 4}, + ], dtype='string') + elif request.param == 'list_categoricalonly_nonan': + return [ + ['a', 'b', 'c', 'd'], + ['e', 'f', 'c', 'd'], + ] + elif request.param == 'list_numericalonly_nonan': + return [ + [1, 2, 3, 4], + [5, 6, 7, 8] + ] + elif request.param == 'list_mixed_nonan': + return [ + ['a', 2, 3, 4], + ['b', 6, 7, 8] + 
] + elif request.param == 'list_categoricalonly_nan': + return [ + ['a', 'b', 'c', np.nan], + ['e', 'f', 'c', 'd'], + ] + elif request.param == 'list_numericalonly_nan': + return [ + [1, 2, 3, np.nan], + [5, 6, 7, 8] + ] + elif request.param == 'list_mixed_nan': + return [ + ['a', np.nan, 3, 4], + ['b', 6, 7, 8] + ] + elif 'sparse' in request.param: + # We expect the names to be of the type sparse_csc_nonan + sparse_, type_, nan_ = request.param.split('_') + if 'nonan' in nan_: + data = np.ones(3) + else: + data = np.array([1, 2, np.nan]) + + # Then the type of sparse + row_ind = np.array([0, 1, 2]) + col_ind = np.array([1, 2, 1]) + if 'csc' in type_: + return sparse.csc_matrix((data, (row_ind, col_ind))) + elif 'csr' in type_: + return sparse.csr_matrix((data, (row_ind, col_ind))) + elif 'coo' in type_: + return sparse.coo_matrix((data, (row_ind, col_ind))) + elif 'bsr' in type_: + return sparse.bsr_matrix((data, (row_ind, col_ind))) + elif 'lil' in type_: + return sparse.lil_matrix((data)) + elif 'dok' in type_: + return sparse.dok_matrix(np.vstack((data, data, data))) + elif 'dia' in type_: + return sparse.dia_matrix(np.vstack((data, data, data))) + else: + ValueError("Unsupported indirect fixture {}".format(request.param)) + elif 'openml' in request.param: + _, openml_id = request.param.split('_') + X, y = sklearn.datasets.fetch_openml(data_id=int(openml_id), + return_X_y=True, as_frame=True) + return X + else: + ValueError("Unsupported indirect fixture {}".format(request.param)) + + +# Actual checks for the features +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'numpy_categoricalonly_nonan', + 'numpy_numericalonly_nonan', + 'numpy_mixed_nonan', + 'numpy_categoricalonly_nan', + 'numpy_numericalonly_nan', + 'numpy_mixed_nan', + 'pandas_categoricalonly_nonan', + 'pandas_numericalonly_nonan', + 'pandas_mixed_nonan', + 'pandas_numericalonly_nan', + 'list_numericalonly_nonan', + 'list_numericalonly_nan', + 'sparse_bsr_nonan', + 'sparse_bsr_nan', + 'sparse_coo_nonan', + 'sparse_coo_nan', + 'sparse_csc_nonan', + 'sparse_csc_nan', + 'sparse_csr_nonan', + 'sparse_csr_nan', + 'sparse_dia_nonan', + 'sparse_dia_nan', + 'sparse_dok_nonan', + 'sparse_dok_nan', + 'sparse_lil_nonan', + 'sparse_lil_nan', + 'openml_40981', # Australian + ), + indirect=True +) +def test_featurevalidator_supported_types(input_data_featuretest): + validator = FeatureValidator() + validator.fit(input_data_featuretest, input_data_featuretest) + transformed_X = validator.transform(input_data_featuretest) + if sparse.issparse(input_data_featuretest): + assert sparse.issparse(transformed_X) + else: + assert isinstance(transformed_X, np.ndarray) + assert np.shape(input_data_featuretest) == np.shape(transformed_X) + assert np.issubdtype(transformed_X.dtype, np.number) + assert validator._is_fitted + + +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'list_categoricalonly_nonan', + 'list_categoricalonly_nan', + 'list_mixed_nonan', + 'list_mixed_nan', + ), + indirect=True +) +def test_featurevalidator_unsupported_list(input_data_featuretest): + validator = FeatureValidator() + with pytest.raises(ValueError, match=r".*has invalid type object. 
Cast it to a valid dtype.*"): + validator.fit(input_data_featuretest) + + +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'numpy_string_nonan', + 'numpy_string_nan', + ), + indirect=True +) +def test_featurevalidator_unsupported_numpy(input_data_featuretest): + validator = FeatureValidator() + with pytest.raises(ValueError, match=r".*When providing a numpy array.*not supported."): + validator.fit(input_data_featuretest) + + +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'pandas_categoricalonly_nan', + 'pandas_mixed_nan', + 'openml_179', # adult workclass has NaN in columns + ), + indirect=True +) +def test_featurevalidator_unsupported_pandas(input_data_featuretest): + validator = FeatureValidator() + with pytest.raises(ValueError, match=r"Categorical features in a dataframe.*missing/NaN"): + validator.fit(input_data_featuretest) + + +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'numpy_categoricalonly_nonan', + 'numpy_mixed_nonan', + 'numpy_categoricalonly_nan', + 'numpy_mixed_nan', + 'pandas_categoricalonly_nonan', + 'pandas_mixed_nonan', + 'list_numericalonly_nonan', + 'list_numericalonly_nan', + 'sparse_bsr_nonan', + 'sparse_bsr_nan', + 'sparse_coo_nonan', + 'sparse_coo_nan', + 'sparse_csc_nonan', + 'sparse_csc_nan', + 'sparse_csr_nonan', + 'sparse_csr_nan', + 'sparse_dia_nonan', + 'sparse_dia_nan', + 'sparse_dok_nonan', + 'sparse_dok_nan', + 'sparse_lil_nonan', + ), + indirect=True +) +def test_featurevalidator_fitontypeA_transformtypeB(input_data_featuretest): + """ + Check if we can fit in a given type (numpy) yet transform + if the user changes the type (pandas then) + + This is problematic only in the case we create an encoder + """ + validator = FeatureValidator() + validator.fit(input_data_featuretest, input_data_featuretest) + if isinstance(input_data_featuretest, pd.DataFrame): + complementary_type = input_data_featuretest.to_numpy() + elif isinstance(input_data_featuretest, np.ndarray): + complementary_type = pd.DataFrame(input_data_featuretest) + elif isinstance(input_data_featuretest, list): + complementary_type = pd.DataFrame(input_data_featuretest) + elif sparse.issparse(input_data_featuretest): + complementary_type = sparse.csr_matrix(input_data_featuretest.todense()) + else: + raise ValueError(type(input_data_featuretest)) + transformed_X = validator.transform(complementary_type) + assert np.shape(input_data_featuretest) == np.shape(transformed_X) + assert np.issubdtype(transformed_X.dtype, np.number) + assert validator._is_fitted + + +def test_featurevalidator_get_columns_to_encode(): + """ + Makes sure that encoded columns are returned by _get_columns_to_encode + whereas numerical columns are not returned + """ + validator = FeatureValidator() + + df = pd.DataFrame([ + {'int': 1, 'float': 1.0, 'category': 'one', 'bool': True}, + {'int': 2, 'float': 2.0, 'category': 'two', 'bool': False}, + ]) + + for col in df.columns: + df[col] = df[col].astype(col) + + enc_columns, feature_types = validator._get_columns_to_encode(df) + + assert enc_columns == ['category', 'bool'] + assert feature_types == ['numerical', 'numerical', 'categorical', 'categorical'] + + +def test_features_unsupported_calls_are_raised(): + """ + Makes sure we raise a proper message to the user, + when providing not supported data input or using the validator in a way that is not + expected + """ + validator = FeatureValidator() + with pytest.raises(ValueError, match=r"Auto-sklearn does not support time"): + validator.fit( + pd.DataFrame({'datetime': 
[pd.Timestamp('20180310')]}) + ) + with pytest.raises(ValueError, match="has invalid type object"): + validator.fit( + pd.DataFrame({'string': ['foo']}) + ) + with pytest.raises(ValueError, match=r"Auto-sklearn only supports.*yet, the provided input"): + validator.fit({'input1': 1, 'input2': 2}) + with pytest.raises(ValueError, match=r"has unsupported dtype string"): + validator.fit(pd.DataFrame([{'A': 1, 'B': 2}], dtype='string')) + with pytest.raises(ValueError, match=r"The feature dimensionality of the train and test"): + validator.fit(X_train=np.array([[1, 2, 3], [4, 5, 6]]), + X_test=np.array([[1, 2, 3, 4], [4, 5, 6, 7]]), + ) + with pytest.raises(ValueError, match=r"Cannot call transform on a validator that is not fit"): + validator.transform(np.array([[1, 2, 3], [4, 5, 6]])) + validator.feat_type = ['Numerical'] + with pytest.raises(ValueError, match=r"providing the option feat_type to the fit method is.*"): + validator.fit(pd.DataFrame([[1, 2, 3], [4, 5, 6]])) + with pytest.raises(ValueError, match=r"Array feat_type does not have same number of.*"): + validator.fit(np.array([[1, 2, 3], [4, 5, 6]])) + validator.feat_type = [1, 2, 3] + with pytest.raises(ValueError, match=r"Array feat_type must only contain strings.*"): + validator.fit(np.array([[1, 2, 3], [4, 5, 6]])) + validator.feat_type = ['1', '2', '3'] + with pytest.raises(ValueError, match=r"Only `Categorical` and `Numerical` are.*"): + validator.fit(np.array([[1, 2, 3], [4, 5, 6]])) + + +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'numpy_numericalonly_nonan', + 'numpy_numericalonly_nan', + 'pandas_numericalonly_nonan', + 'pandas_numericalonly_nan', + 'list_numericalonly_nonan', + 'list_numericalonly_nan', + # Category in numpy is handled via feat_type + 'numpy_categoricalonly_nonan', + 'numpy_mixed_nonan', + 'numpy_categoricalonly_nan', + 'numpy_mixed_nan', + 'sparse_bsr_nonan', + 'sparse_bsr_nan', + 'sparse_coo_nonan', + 'sparse_coo_nan', + 'sparse_csc_nonan', + 'sparse_csc_nan', + 'sparse_csr_nonan', + 'sparse_csr_nan', + 'sparse_dia_nonan', + 'sparse_dia_nan', + 'sparse_dok_nonan', + 'sparse_dok_nan', + 'sparse_lil_nonan', + 'sparse_lil_nan', + ), + indirect=True +) +def test_no_encoder_created(input_data_featuretest): + """ + Makes sure that for numerical only features, no encoder is created + """ + validator = FeatureValidator() + validator.fit(input_data_featuretest) + validator.transform(input_data_featuretest) + assert validator.encoder is None + + +@pytest.mark.parametrize( + 'input_data_featuretest', + ( + 'pandas_categoricalonly_nonan', + 'pandas_mixed_nonan', + ), + indirect=True +) +def test_encoder_created(input_data_featuretest): + """ + This test ensures an encoder is created if categorical data is provided + """ + validator = FeatureValidator() + validator.fit(input_data_featuretest) + transformed_X = validator.transform(input_data_featuretest) + assert validator.encoder is not None + + # Make sure that the encoded features are actually encoded. Categorical columns are at + # the start after transformation. 
In our fixtures, this is also honored prior encode + enc_columns, feature_types = validator._get_columns_to_encode(input_data_featuretest) + + # At least one categorical + assert 'categorical' in validator.feat_type + + # Numerical if the original data has numerical only columns + if np.any([pd.api.types.is_numeric_dtype(input_data_featuretest[col] + ) for col in input_data_featuretest.columns]): + assert 'numerical' in validator.feat_type + for i, feat_type in enumerate(feature_types): + if 'numerical' in feat_type: + np.testing.assert_array_equal( + transformed_X[:, i], + input_data_featuretest[input_data_featuretest.columns[i]].to_numpy() + ) + elif 'categorical' in feat_type: + np.testing.assert_array_equal( + transformed_X[:, i], + # Expect always 0, 1... because we use a ordinal encoder + np.array([0, 1]) + ) + else: + raise ValueError(feat_type) + + +def test_no_new_category_after_fit(): + """ + This test makes sure that we can actually pass new categories to the estimator + without throwing an error + """ + # Then make sure we catch categorical extra categories + x = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [5, 6, 7, 8]}, dtype='category') + validator = FeatureValidator() + validator.fit(x) + x['A'] = x['A'].apply(lambda x: x*x) + validator.transform(x) + + +def test_unknown_encode_value(): + x = pd.DataFrame([ + {'a': -41, 'b': -3, 'c': 'a', 'd': -987.2}, + {'a': -21, 'b': -3, 'c': 'a', 'd': -9.2}, + {'a': 0, 'b': -4, 'c': 'b', 'd': -97.2}, + {'a': -51, 'b': -3, 'c': 'a', 'd': 987.2}, + {'a': 500, 'b': -3, 'c': 'a', 'd': -92}, + ]) + x['c'] = x['c'].astype('category') + validator = FeatureValidator() + + # Make sure that this value is honored + validator.fit(x) + x['c'].cat.add_categories(['NA'], inplace=True) + x.loc[0, 'c'] = 'NA' # unknown value + x_t = validator.transform(x) + # The first row should have a -1 as we added a new categorical there + expected_row = [-1, -41, -3, -987.2] + assert expected_row == x_t[0].tolist() + + +# Actual checks for the features +@pytest.mark.parametrize( + 'openml_id', + ( + 40981, # Australian + 3, # kr-vs-kp + 1468, # cnae-9 + 40975, # car + 40984, # Segment + ), +) +@pytest.mark.parametrize('train_data_type', ('numpy', 'pandas', 'list')) +@pytest.mark.parametrize('test_data_type', ('numpy', 'pandas', 'list')) +def test_featurevalidator_new_data_after_fit(openml_id, + train_data_type, test_data_type): + + # List is currently not supported as infer_objects + # cast list objects to type objects + if train_data_type == 'list' or test_data_type == 'list': + pytest.skip() + + validator = FeatureValidator() + + if train_data_type == 'numpy': + X, y = sklearn.datasets.fetch_openml(data_id=openml_id, + return_X_y=True, as_frame=False) + elif train_data_type == 'pandas': + X, y = sklearn.datasets.fetch_openml(data_id=openml_id, + return_X_y=True, as_frame=True) + else: + X, y = sklearn.datasets.fetch_openml(data_id=openml_id, + return_X_y=True, as_frame=True) + X = X.values.tolist() + y = y.values.tolist() + + X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X, y, random_state=1) + + validator.fit(X_train) + + transformed_X = validator.transform(X_test) + + # Basic Checking + if sparse.issparse(input_data_featuretest): + assert sparse.issparse(transformed_X) + else: + assert isinstance(transformed_X, np.ndarray) + assert np.shape(X_test) == np.shape(transformed_X) + + # And then check proper error messages + if train_data_type == 'pandas': + old_dtypes = copy.deepcopy(validator.dtypes) + validator.dtypes = ['dummy' for dtype 
in X_train.dtypes] + with pytest.raises(ValueError, match=r"hanging the dtype of the features after fit"): + transformed_X = validator.transform(X_test) + validator.dtypes = old_dtypes + if test_data_type == 'pandas': + columns = X_test.columns.tolist() + random.shuffle(columns) + X_test = X_test[columns] + with pytest.raises(ValueError, match=r"Changing the column order of the features"): + transformed_X = validator.transform(X_test) diff --git a/test/test_data/test_target_validator.py b/test/test_data/test_target_validator.py new file mode 100644 index 0000000000..76b3cabd2d --- /dev/null +++ b/test/test_data/test_target_validator.py @@ -0,0 +1,503 @@ +import numpy as np + +import pandas as pd + +import pytest +from pandas.api.types import is_numeric_dtype + +from scipy import sparse + +import sklearn.datasets +import sklearn.model_selection +from sklearn.utils.multiclass import type_of_target + +from autosklearn.data.target_validator import TargetValidator + + +# Fixtures to be used in this class. By default all elements have 100 datapoints +@pytest.fixture +def input_data_targettest(request): + if request.param == 'series_binary': + return pd.Series([1, -1, -1, 1]) + elif request.param == 'series_multiclass': + return pd.Series([1, 0, 2]) + elif request.param == 'series_multilabel': + return pd.Series([[1, 0], [0, 1]]) + elif request.param == 'series_continuous': + return pd.Series([0.1, 0.6, 0.7]) + elif request.param == 'series_continuous-multioutput': + return pd.Series([[1.5, 2.0], [3.0, 1.6]]) + elif request.param == 'pandas_binary': + return pd.DataFrame([1, -1, -1, 1]) + elif request.param == 'pandas_multiclass': + return pd.DataFrame([1, 0, 2]) + elif request.param == 'pandas_multilabel': + return pd.DataFrame([[1, 0], [0, 1]]) + elif request.param == 'pandas_continuous': + return pd.DataFrame([0.1, 0.6, 0.7]) + elif request.param == 'pandas_continuous-multioutput': + return pd.DataFrame([[1.5, 2.0], [3.0, 1.6]]) + elif request.param == 'numpy_binary': + return np.array([1, -1, -1, 1]) + elif request.param == 'numpy_multiclass': + return np.array([1, 0, 2]) + elif request.param == 'numpy_multilabel': + return np.array([[1, 0], [0, 1]]) + elif request.param == 'numpy_continuous': + return np.array([0.1, 0.6, 0.7]) + elif request.param == 'numpy_continuous-multioutput': + return np.array([[1.5, 2.0], [3.0, 1.6]]) + elif request.param == 'list_binary': + return [1, -1, -1, 1] + elif request.param == 'list_multiclass': + return [1, 0, 2] + elif request.param == 'list_multilabel': + return [[0, 1], [1, 0]] + elif request.param == 'list_continuous': + return [0.1, 0.6, 0.7] + elif request.param == 'list_continuous-multioutput': + return [[1.5, 2.0], [3.0, 1.6]] + elif 'openml' in request.param: + _, openml_id = request.param.split('_') + X, y = sklearn.datasets.fetch_openml(data_id=int(openml_id), + return_X_y=True, as_frame=True) + if len(y.shape) > 1 and y.shape[1] > 1 and np.any(y.eq('TRUE').any(1).to_numpy()): + # This 'if' is only asserted for multi-label data + # Force the downloaded data to be interpreted as multilabel + y = y.dropna() + y.replace('FALSE', 0, inplace=True) + y.replace('TRUE', 1, inplace=True) + y = y.astype(np.int) + return y + elif 'sparse' in request.param: + # We expect the names to be of the type sparse_csc_nonan + sparse_, type_, nan_ = request.param.split('_') + if 'nonan' in nan_: + data = np.ones(3) + else: + data = np.array([1, 2, np.nan]) + + # Then the type of sparse + if 'csc' in type_: + return sparse.csc_matrix(data) + elif 'csr' in type_: + 
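# CSR: the same nonzero data, stored in compressed-sparse-row layout
+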
return sparse.csr_matrix(data) + elif 'coo' in type_: + return sparse.coo_matrix(data) + elif 'bsr' in type_: + return sparse.bsr_matrix(data) + elif 'lil' in type_: + return sparse.lil_matrix(data) + elif 'dok' in type_: + return sparse.dok_matrix(np.vstack((data, data, data))) + elif 'dia' in type_: + return sparse.dia_matrix(np.vstack((data, data, data))) + else: + ValueError("Unsupported indirect fixture {}".format(request.param)) + else: + ValueError("Unsupported indirect fixture {}".format(request.param)) + + +# Actual checks for the targets +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'series_binary', + 'series_multiclass', + 'series_continuous', + 'pandas_binary', + 'pandas_multiclass', + 'pandas_multilabel', + 'pandas_continuous', + 'pandas_continuous-multioutput', + 'numpy_binary', + 'numpy_multiclass', + 'numpy_multilabel', + 'numpy_continuous', + 'numpy_continuous-multioutput', + 'list_binary', + 'list_multiclass', + 'list_multilabel', + 'list_continuous', + 'list_continuous-multioutput', + 'sparse_bsr_nonan', + 'sparse_coo_nonan', + 'sparse_csc_nonan', + 'sparse_csr_nonan', + 'sparse_lil_nonan', + 'openml_204', + ), + indirect=True +) +def test_targetvalidator_supported_types_noclassification(input_data_targettest): + validator = TargetValidator(is_classification=False) + validator.fit(input_data_targettest) + transformed_y = validator.transform(input_data_targettest) + if sparse.issparse(input_data_targettest): + assert sparse.issparse(transformed_y) + else: + assert isinstance(transformed_y, np.ndarray) + epected_shape = np.shape(input_data_targettest) + if len(epected_shape) > 1 and epected_shape[1] == 1: + # The target should have (N,) dimensionality instead of (N, 1) + epected_shape = (epected_shape[0], ) + assert epected_shape == np.shape(transformed_y) + assert np.issubdtype(transformed_y.dtype, np.number) + assert validator._is_fitted + + # Because there is no classification, we do not expect a encoder + assert validator.encoder is None + + if hasattr(input_data_targettest, "iloc"): + np.testing.assert_array_equal( + np.ravel(input_data_targettest.to_numpy()), + np.ravel(transformed_y) + ) + elif sparse.issparse(input_data_targettest): + np.testing.assert_array_equal( + np.ravel(input_data_targettest.todense()), + np.ravel(transformed_y.todense()) + ) + else: + np.testing.assert_array_equal( + np.ravel(np.array(input_data_targettest)), + np.ravel(transformed_y) + ) + + +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'series_binary', + 'series_multiclass', + 'pandas_binary', + 'pandas_multiclass', + 'numpy_binary', + 'numpy_multiclass', + 'list_binary', + 'list_multiclass', + 'sparse_bsr_nonan', + 'sparse_coo_nonan', + 'sparse_csc_nonan', + 'sparse_csr_nonan', + 'sparse_lil_nonan', + 'openml_2', + ), + indirect=True +) +def test_targetvalidator_supported_types_classification(input_data_targettest): + validator = TargetValidator(is_classification=True) + validator.fit(input_data_targettest) + transformed_y = validator.transform(input_data_targettest) + if sparse.issparse(input_data_targettest): + assert sparse.issparse(transformed_y) + else: + assert isinstance(transformed_y, np.ndarray) + epected_shape = np.shape(input_data_targettest) + if len(epected_shape) > 1 and epected_shape[1] == 1: + # The target should have (N,) dimensionality instead of (N, 1) + epected_shape = (epected_shape[0], ) + assert epected_shape == np.shape(transformed_y) + assert np.issubdtype(transformed_y.dtype, np.number) + assert validator._is_fitted + + # Because there 
is a classification task, we expect an encoder for non-sparse targets
+    if not sparse.issparse(input_data_targettest):
+        assert validator.encoder is not None
+
+        # The encoding should be per column
+        if len(transformed_y.shape) == 1:
+            assert np.min(transformed_y) == 0
+            assert np.max(transformed_y) == len(np.unique(transformed_y)) - 1
+        else:
+            for col in range(transformed_y.shape[1]):
+                assert np.min(transformed_y[:, col]) == 0
+                assert np.max(transformed_y[:, col]) == len(np.unique(transformed_y[:, col])) - 1
+
+        # Make sure we can perform inverse transform
+        y_inverse = validator.inverse_transform(transformed_y)
+        if hasattr(input_data_targettest, 'dtype'):
+            # In case of numeric, we need to make sure dtype is preserved
+            if is_numeric_dtype(input_data_targettest.dtype):
+                assert y_inverse.dtype == input_data_targettest.dtype
+            # Then make sure every value is properly inverse-transformed
+            np.testing.assert_array_equal(np.array(y_inverse), np.array(input_data_targettest))
+        elif hasattr(input_data_targettest, 'dtypes'):
+            if is_numeric_dtype(input_data_targettest.dtypes[0]):
+                assert y_inverse.dtype == input_data_targettest.dtypes[0]
+            # Then make sure every value is properly inverse-transformed
+            np.testing.assert_array_equal(np.array(y_inverse),
+                                          # pandas is always (N, 1) but targets are ravel()
+                                          input_data_targettest.to_numpy().reshape(-1))
+    else:
+        # Sparse is not encoded, mainly because the sparse data is expected
+        # to be numpy of numerical type -- which currently does not require encoding
+        np.testing.assert_array_equal(
+            np.ravel(input_data_targettest.todense()),
+            np.ravel(transformed_y.todense())
+        )
+
+
+@pytest.mark.parametrize(
+    'input_data_targettest',
+    (
+        'series_binary',
+        'pandas_binary',
+        'numpy_binary',
+        'list_binary',
+        'openml_1066',
+    ),
+    indirect=True
+)
+def test_targetvalidator_binary(input_data_targettest):
+    assert type_of_target(input_data_targettest) == 'binary'
+    validator = TargetValidator(is_classification=True)
+    # Test the y_test path also!
+    validator.fit(input_data_targettest, input_data_targettest)
+    transformed_y = validator.transform(input_data_targettest)
+    assert type_of_target(transformed_y) == 'binary'
+
+
+@pytest.mark.parametrize(
+    'input_data_targettest',
+    (
+        'series_multiclass',
+        'pandas_multiclass',
+        'numpy_multiclass',
+        'list_multiclass',
+        'openml_54',
+    ),
+    indirect=True
+)
+def test_targetvalidator_multiclass(input_data_targettest):
+    assert type_of_target(input_data_targettest) == 'multiclass'
+    validator = TargetValidator(is_classification=True)
+    # Test the y_test path also!
+    validator.fit(input_data_targettest, input_data_targettest)
+    transformed_y = validator.transform(input_data_targettest)
+    assert type_of_target(transformed_y) == 'multiclass'
+
+
+@pytest.mark.parametrize(
+    'input_data_targettest',
+    (
+        'pandas_multilabel',
+        'numpy_multilabel',
+        'list_multilabel',
+        'openml_40594',
+    ),
+    indirect=True
+)
+def test_targetvalidator_multilabel(input_data_targettest):
+    assert type_of_target(input_data_targettest) == 'multilabel-indicator'
+    validator = TargetValidator(is_classification=True)
+    # Test the y_test path also!
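+    # fit() below receives the targets twice, once as y_train and once as
+    # y_test, so the label categories are learned from both splits jointly.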
+ validator.fit(input_data_targettest, input_data_targettest) + transformed_y = validator.transform(input_data_targettest) + assert type_of_target(transformed_y) == 'multilabel-indicator' + + +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'series_continuous', + 'pandas_continuous', + 'numpy_continuous', + 'list_continuous', + 'openml_531', + ), + indirect=True +) +def test_targetvalidator_continuous(input_data_targettest): + assert type_of_target(input_data_targettest) == 'continuous' + validator = TargetValidator(is_classification=False) + # Test the X_test also! + validator.fit(input_data_targettest, input_data_targettest) + transformed_y = validator.transform(input_data_targettest) + assert type_of_target(transformed_y) == 'continuous' + + +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'pandas_continuous-multioutput', + 'numpy_continuous-multioutput', + 'list_continuous-multioutput', + 'openml_41483', + ), + indirect=True +) +def test_targetvalidator_continuous_multioutput(input_data_targettest): + assert type_of_target(input_data_targettest) == 'continuous-multioutput' + validator = TargetValidator(is_classification=False) + # Test the X_test also! + validator.fit(input_data_targettest, input_data_targettest) + transformed_y = validator.transform(input_data_targettest) + assert type_of_target(transformed_y) == 'continuous-multioutput' + + +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'series_binary', + 'pandas_binary', + 'numpy_binary', + 'list_binary', + ), + indirect=True +) +def test_targetvalidator_fitontypeA_transformtypeB(input_data_targettest): + """ + Check if we can fit in a given type (numpy) yet transform + if the user changes the type (pandas then) + + This is problematic only in the case we create an encoder + """ + validator = TargetValidator(is_classification=True) + validator.fit(input_data_targettest) + if isinstance(input_data_targettest, pd.DataFrame): + complementary_type = input_data_targettest.to_numpy() + elif isinstance(input_data_targettest, pd.Series): + complementary_type = pd.DataFrame(input_data_targettest) + elif isinstance(input_data_targettest, np.ndarray): + complementary_type = pd.DataFrame(input_data_targettest) + elif isinstance(input_data_targettest, list): + complementary_type = pd.DataFrame(input_data_targettest) + validator.transform(complementary_type) + + +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'series_multilabel', + 'series_continuous-multioutput', + ), + indirect=True +) +def test_type_of_target_unsupported(input_data_targettest): + """ + Makes sure we raise a proper message to the user, + when providing not supported data input + """ + validator = TargetValidator() + with pytest.raises(ValueError, match=r"legacy multi-.* data representation."): + validator.fit(input_data_targettest) + + +def test_target_unsupported(): + """ + Makes sure we raise a proper message to the user, + when providing not supported data input + """ + validator = TargetValidator(is_classification=True) + with pytest.raises(ValueError, match=r"The dimensionality of the train and test targets"): + validator.fit( + np.array([[0, 1, 0], [0, 1, 1]]), + np.array([[0, 1, 0, 0], [0, 1, 1, 1]]), + ) + with pytest.raises(ValueError, match=r"Train and test targets must both have the same dtypes"): + validator.fit( + pd.DataFrame({'a': [1, 2, 3]}), + pd.DataFrame({'a': [True, False, False]}), + ) + with pytest.raises(ValueError, match=r"Provided targets are not supported.*"): + validator.fit( + np.array([[0, 1, 2], [0, 3, 
4]]), + np.array([[0, 1, 2, 5], [0, 3, 4, 6]]), + ) + with pytest.raises(ValueError, match="Train and test targets must both have the same"): + validator.fit( + pd.DataFrame({'string': ['foo']}), + pd.DataFrame({'int': [1]}), + ) + with pytest.raises(ValueError, match=r"Auto-sklearn only supports Numpy arrays, .*"): + validator.fit({'input1': 1, 'input2': 2}) + with pytest.raises(ValueError, match=r"arget values cannot contain missing/NaN values"): + validator.fit(np.array([np.nan, 1, 2])) + with pytest.raises(ValueError, match=r"arget values cannot contain missing/NaN values"): + validator.fit(sparse.csr_matrix(np.array([1, 2, np.nan]))) + with pytest.raises(ValueError, match=r"Cannot call transform on a validator that is not fit"): + validator.transform(np.array([1, 2, 3])) + with pytest.raises(ValueError, match=r"Cannot call inverse_transform on a validator that is"): + validator.inverse_transform(np.array([1, 2, 3])) + with pytest.raises(ValueError, match=r"Multi-dimensional classification is not yet supported"): + validator._fit(np.array([[1, 2, 3], [1, 5, 6]])) + + # Dia/ DOK are not supported as type of target makes calls len on the array + # which causes TypeError: len() of unsized object. Basically, sparse data as + # multi-label is the only thing that makes sense in this format. + with pytest.raises(ValueError, match=r"The provided data could not be interpreted by Sklearn"): + validator.fit(sparse.dia_matrix(np.array([1, 2, 3]))) + + validator.fit(np.array([[0, 1, 0], [0, 1, 1]])) + with pytest.raises(ValueError, match=r"Number of outputs changed from"): + validator.fit(np.array([0, 1, 0])) + + +def test_targetvalidator_inversetransform(): + """ + Test that the encoding/decoding works in 1D + """ + validator = TargetValidator(is_classification=True) + validator.fit( + pd.DataFrame(data=['a', 'a', 'b', 'c', 'a'], dtype='category'), + ) + y = validator.transform( + pd.DataFrame(data=['a', 'a', 'b', 'c', 'a'], dtype='category'), + ) + np.testing.assert_array_almost_equal(np.array([0, 0, 1, 2, 0]), y) + + y_decoded = validator.inverse_transform(y) + assert ['a', 'a', 'b', 'c', 'a'] == y_decoded.tolist() + + validator = TargetValidator(is_classification=True) + multi_label = pd.DataFrame( + np.array([[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]]), + dtype=bool + ) + validator.fit(multi_label) + y = validator.transform(multi_label) + + y_decoded = validator.inverse_transform(y) + np.testing.assert_array_almost_equal(y, y_decoded) + + +# Actual checks for the targets +@pytest.mark.parametrize( + 'input_data_targettest', + ( + 'series_binary', + 'series_multiclass', + 'pandas_binary', + 'pandas_multiclass', + 'numpy_binary', + 'numpy_multiclass', + 'list_binary', + 'list_multiclass', + ), + indirect=True +) +def test_unknown_categories_in_targets(input_data_targettest): + validator = TargetValidator(is_classification=True) + validator.fit(input_data_targettest) + + # Add an extra category + if isinstance(input_data_targettest, list): + input_data_targettest.append(input_data_targettest[-1] + 5000) + elif isinstance(input_data_targettest, (pd.DataFrame, pd.Series)): + input_data_targettest.iloc[-1] = 5000 + elif isinstance(input_data_targettest, np.ndarray): + input_data_targettest[-1] = 5000 + + x_t = validator.transform(input_data_targettest) + assert x_t[-1].item(0) == -1 + + +def test_is_single_column_target(): + validator = TargetValidator(is_classification=True) + validator.fit(np.array([1, 2, 3, 4])) + assert validator.is_single_column_target() + + validator = 
TargetValidator(is_classification=True) + validator.fit(np.array([[1, 0, 1, 0], [1, 1, 1, 1]])) + assert not validator.is_single_column_target() diff --git a/test/test_data/test_validation.py b/test/test_data/test_validation.py index 4ed39f9127..47e74a6776 100644 --- a/test/test_data/test_validation.py +++ b/test/test_data/test_validation.py @@ -1,666 +1,139 @@ -import itertools - -import unittest -import unittest.mock - import numpy as np import pandas as pd +import pytest + from scipy import sparse import sklearn.datasets import sklearn.model_selection -from sklearn.utils.multiclass import type_of_target from autosklearn.data.validation import InputValidator -class InputValidatorTest(unittest.TestCase): - - def setUp(self): - self.X = [ - [2.5, 3.3, 2, 5, 1, 1], - [1.0, 0.7, 1, 5, 1, 0], - [1.3, 0.8, 1, 4, 1, 1] - ] - self.y = [0, 1, 0] - - def test_list_input(self): - """ - Makes sure that a list is converted to nparray - """ - validator = InputValidator() - X, y = validator.validate(self.X, self.y) - - self.assertIsInstance(X, np.ndarray) - self.assertIsInstance(y, np.ndarray) - - def test_numpy_input(self): - """ - Makes sure that no encoding is needed for a - numpy float object. Also test features/target - validation methods - """ - validator = InputValidator() - X = validator.validate_features( - np.array(self.X), - ) - y = validator.validate_target( - np.array(self.y) - ) - - self.assertIsInstance(X, np.ndarray) - self.assertIsInstance(y, np.ndarray) - self.assertIsNone(validator.target_encoder) - self.assertIsNone(validator.feature_encoder) - - def test_sparse_numpy_input(self): - """ - Makes sure that no encoder is needed when - working with sparse float data - """ - validator = InputValidator() - - # Sparse data - row_ind = np.array([0, 1, 2]) - col_ind = np.array([1, 2, 1]) - X_sparse = sparse.csr_matrix((np.ones(3), (row_ind, col_ind))) - X = validator.validate_features( - X_sparse, - ) - y = validator.validate_target( - np.array(self.y) - ) - - self.assertIsInstance(X, sparse.csr.csr_matrix) - self.assertIsInstance(y, np.ndarray) - self.assertIsNone(validator.target_encoder) - self.assertIsNone(validator.feature_encoder) - - # Sparse targets should not be supported - data = np.array([1, 2, 3, 4, 5, 6]) - col = np.array([0, 0, 0, 0, 0, 0]) - row = np.array([0, 2, 3, 6, 7, 10]) - y = sparse.csr_matrix((data, (row, col)), shape=(11, 1)) - with self.assertRaisesRegex(ValueError, 'scipy.sparse.csr_matrix.todense'): - validator = InputValidator().validate_target(y) - - def test_dataframe_input_numerical(self): - """ - Makes sure that we don't encode numerical data - """ - for test_type in ['int64', 'float64', 'int8']: - validator = InputValidator() - X = validator.validate_features( - pd.DataFrame(data=self.X, dtype=test_type), - ) - y = validator.validate_target( - pd.DataFrame(data=self.y, dtype=test_type), - ) - - self.assertIsInstance(X, np.ndarray) - self.assertIsInstance(y, np.ndarray) - self.assertIsNone(validator.target_encoder) - self.assertIsNone(validator.feature_encoder) - - def test_dataframe_input_categorical(self): - """ - Makes sure we automatically encode categorical data - """ - for test_type in ['bool', 'category']: - validator = InputValidator() - X = validator.validate_features( - pd.DataFrame(data=self.X, dtype=test_type), - ) - y = validator.validate_target( - pd.DataFrame(data=self.y, dtype=test_type), - is_classification=True, - ) - - self.assertIsInstance(X, np.ndarray) - self.assertIsInstance(y, np.ndarray) - self.assertIsNotNone(validator.target_encoder) 
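
A minimal sketch of the TargetValidator round trip that the new tests above exercise (the import path is assumed here, since it is not shown in this hunk; the encoded values and the -1 for unseen categories follow the assertions in the tests):

import numpy as np
import pandas as pd
from autosklearn.data.target_validator import TargetValidator  # import path assumed

validator = TargetValidator(is_classification=True)
validator.fit(pd.Series(['a', 'a', 'b', 'c', 'a']))                 # fit on one container type...
encoded = validator.transform(np.array(['a', 'a', 'b', 'c', 'a']))  # ...transform another
# encoded is [0, 0, 1, 2, 0]; categories unseen during fit are mapped to -1
decoded = validator.inverse_transform(encoded)                      # back to 'a', 'a', 'b', 'c', 'a'
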
- self.assertIsNotNone(validator.feature_encoder) - - def test_binary_conversion(self): - """ - Makes sure that a encoded target for classification - properly retains the binary target type - """ - validator = InputValidator() - - # Just 2 classes, 1 and 2 - y_train = validator.validate_target( - np.array([1.0, 2.0, 2.0, 1.0], dtype=np.float64), - is_classification=True, - ) - self.assertEqual('binary', type_of_target(y_train)) - - # Also make sure that a re-use of the generator is also binary - y_valid = validator.validate_target( - np.array([2.0, 2.0, 2.0, 2.0], dtype=np.float64), - is_classification=True, - ) - self.assertEqual('binary', type_of_target(y_valid)) - - # Make sure binary also works with PD dataframes - validator = InputValidator() - - # Just 2 classes, 1 and 2 - y_train = validator.validate_target( - pd.DataFrame([1.0, 2.0, 2.0, 1.0], dtype='category'), - is_classification=True, - ) - self.assertEqual('binary', type_of_target(y_train)) - - def test_multiclass_conversion(self): - """ - Makes sure that a encoded target for classification - properly retains the multiclass target type - """ - # Multiclass conversion for different datatype - for input_object in [ - [1.0, 2.0, 2.0, 4.0, 3], - np.array([1.0, 2.0, 2.0, 4.0, 3], dtype=np.float64), - pd.DataFrame([1.0, 2.0, 2.0, 4.0, 3], dtype='category'), - ]: - validator = InputValidator() - y_train = validator.validate_target( - input_object, - is_classification=True, - ) - self.assertEqual('multiclass', type_of_target(y_train)) - - def test_multilabel_conversion(self): - """ - Makes sure that a encoded target for classification - properly retains the multilabel target type - """ - # Multi-label conversion for different datatype - for input_object in [ - [[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]], - np.array([[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]]), - pd.DataFrame([[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]], dtype='category'), - ]: - validator = InputValidator() - y_train = validator.validate_target( - input_object, - is_classification=True, - ) - self.assertEqual('multilabel-indicator', type_of_target(y_train)) - - def test_continuous_multioutput_conversion(self): - """ - Makes sure that an input for regression - properly retains the multiout continious target type - """ - # Regression multi out conversion for different datatype - for input_object in [ - [[31.4, 94], [40.5, 109], [25.0, 30]], - np.array([[31.4, 94], [40.5, 109], [25.0, 30]]), - pd.DataFrame([[31.4, 94], [40.5, 109], [25.0, 30]]), - ]: - validator = InputValidator() - y_train = validator.validate_target( - input_object, - is_classification=False, - ) - self.assertEqual('continuous-multioutput', type_of_target(y_train)) - - def test_regression_conversion(self): - """ - Makes sure that a regression input - properly retains the continious target type - """ - for input_object in [ - [1.0, 76.9, 123, 4.0, 81.1], - np.array([1.0, 76.9, 123, 4.0, 81.1]), - pd.DataFrame([1.0, 76.9, 123, 4.0, 81.1]), - ]: - validator = InputValidator() - y_train = validator.validate_target( - input_object, - is_classification=False, - ) - self.assertEqual('continuous', type_of_target(y_train)) - - def test_dataframe_input_unsupported(self): - """ - Makes sure we raise a proper message to the user, - when providing not supported data input - """ - validator = InputValidator() - with self.assertRaisesRegex(ValueError, "Auto-sklearn does not support time"): - validator.validate_features( - pd.DataFrame({'datetime': [pd.Timestamp('20180310')]}) - ) - with self.assertRaisesRegex(ValueError, 
"has invalid type object"): - validator.validate_features( - pd.DataFrame({'string': ['foo']}) - ) - - validator = InputValidator() - with self.assertRaisesRegex(ValueError, "Expected 2D array, got"): - validator.validate_features({'input1': 1, 'input2': 2}) - - validator = InputValidator() - with self.assertRaisesRegex(ValueError, "Expected 2D array, got"): - validator.validate_features(InputValidator()) - - validator = InputValidator() - X = pd.DataFrame(data=['a', 'b', 'c'], dtype='category') - with unittest.mock.patch('autosklearn.data.validation.InputValidator._check_and_get_columns_to_encode') as mock_foo: # noqa E501 - # Mock that all columns are ok. There should be a - # checker to catch for bugs - mock_foo.return_value = ([], []) - with self.assertRaisesRegex(ValueError, 'Failed to convert the input'): - validator.validate_features(X) - - def test_dataframe_econding_1D(self): - """ - Test that the encoding/decoding works in 1D - """ - validator = InputValidator() - y = validator.validate_target( - pd.DataFrame(data=self.y, dtype=bool), - is_classification=True, - ) - np.testing.assert_array_almost_equal(np.array([0, 1, 0]), y) - - # Result should not change on a multi call - y = validator.validate_target(pd.DataFrame(data=self.y, dtype=bool)) - np.testing.assert_array_almost_equal(np.array([0, 1, 0]), y) - - y_decoded = validator.decode_target(y) - np.testing.assert_array_almost_equal(np.array(self.y, dtype=bool), y_decoded) - - # Now go with categorical data - validator = InputValidator() - y = validator.validate_target( - pd.DataFrame(data=['a', 'a', 'b', 'c', 'a'], dtype='category'), - is_classification=True, - ) - np.testing.assert_array_almost_equal(np.array([0, 0, 1, 2, 0]), y) - - y_decoded = validator.decode_target(y) - self.assertListEqual(['a', 'a', 'b', 'c', 'a'], y_decoded.tolist()) - - def test_dataframe_econding_2D(self): - """ - Test that the encoding/decoding works in 2D - """ - validator = InputValidator() - multi_label = pd.DataFrame( - np.array([[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]]), - dtype=bool - ) - y = validator.validate_target(multi_label, is_classification=True) - - # Result should not change on a multi call - y_new = validator.validate_target(multi_label) - np.testing.assert_array_almost_equal(y_new, y) - - y_decoded = validator.decode_target(y) - np.testing.assert_array_almost_equal(y, y_decoded) - - def test_noNaN(self): - """ - Makes sure that during classification/regression task, - the transformed data is not corrupted. 
- - Testing is given without Nan and no sparse data - """ - # numpy - categorical - classification - x = np.array(['a', 'b', 'c', 'a', 'b', 'c']).reshape(-1, 1) - validator = InputValidator() - with self.assertRaisesRegex(ValueError, - 'the only valid dtypes are numerical ones'): - x_t, y_t = validator.validate(x, np.copy(x), is_classification=True) - - # numpy - categorical - regression - with self.assertRaisesRegex(ValueError, - 'the only valid dtypes are numerical ones'): - x_t, y_t = validator.validate(x, np.copy(x), is_classification=False) - - # numpy - numerical - classification - x = np.random.random_sample((4, 4)) - y = np.random.choice([0, 1], 4) - validator = InputValidator() - x_t, y_t = validator.validate(x, y, is_classification=True) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.issubdtype(y_t.dtype, np.number)) - self.assertEqual(type_of_target(y_t), 'binary') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - self.assertTupleEqual(np.shape(y), np.shape(y_t)) - - # numpy - numerical - regression - x = np.random.random_sample((4, 4)) - y = np.random.random_sample(4) - validator = InputValidator() - x_t, y_t = validator.validate(x, y, is_classification=False) - np.testing.assert_array_equal(x, x_t) # No change to valid data - np.testing.assert_array_equal(y, y_t) - self.assertEqual(type_of_target(y_t), 'continuous') - - # pandas - categorical - classification - x = pd.DataFrame({'A': np.random.choice(['a', 'b'], 4), - 'B': np.random.choice(['a', 'b'], 4)}, - dtype='category') - y = pd.DataFrame(np.random.choice(['c', 'd'], 4), dtype='category') - validator = InputValidator() - x_t, y_t = validator.validate(x, y, is_classification=True) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.issubdtype(y_t.dtype, np.number)) - self.assertEqual(type_of_target(y_t), 'binary') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - self.assertTupleEqual(np.shape(y.to_numpy().reshape(-1)), np.shape(y_t)) # ravel - - # pandas - categorical - regression - x = pd.DataFrame({'A': np.random.choice(['a', 'b'], 4), - 'B': np.random.choice(['a', 'b'], 4)}, - dtype='category') - y = pd.DataFrame(np.random.random_sample(4)) - validator = InputValidator() - x_t, y_t = validator.validate(x, y, is_classification=False) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.issubdtype(y_t.dtype, np.number)) - self.assertEqual(type_of_target(y_t), 'continuous') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - np.testing.assert_array_equal(y.to_numpy().reshape(-1), y_t) - self.assertTupleEqual(np.shape(y.to_numpy().reshape(-1)), np.shape(y_t)) # ravel version - - # pandas - numerical - classification - x = pd.DataFrame({'A': np.random.random_sample(4), - 'B': np.random.choice([2.5, 1.2], 4)}) - y = pd.DataFrame([1.0, 2.2, 3.2, 2.2]) - validator = InputValidator() - x_t, y_t = validator.validate(x, y, is_classification=True) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.issubdtype(y_t.dtype, np.number)) - self.assertEqual(type_of_target(y_t), 'multiclass') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - np.testing.assert_array_equal(np.array([0, 1, 2, 1]), y_t) - self.assertTupleEqual(np.shape(y.to_numpy().reshape(-1)), np.shape(y_t)) # ravel - - # pandas - numerical - regression - x = pd.DataFrame({'A': np.random.choice([1.5, 3.6], 4), - 'B': np.random.choice([2.5, 1.2], 4)}) - y = pd.DataFrame(np.random.random_sample(4)) - validator = InputValidator() - x_t, y_t = validator.validate(x, 
y, is_classification=False) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.issubdtype(y_t.dtype, np.number)) - self.assertEqual(type_of_target(y_t), 'continuous') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - self.assertTupleEqual(np.shape(y.to_numpy().reshape(-1)), np.shape(y_t)) # ravel - np.testing.assert_array_equal(y.to_numpy().reshape(-1), y_t) - return - - def test_NaN(self): - # numpy - categorical - classification - # np.nan in categorical array means that the array will be - # type string, and np.nan will be casted as 'nan'. - # In turn, 'nan' will be another category - x = np.array([1, 2, 3, 4, 5.0, np.nan]).reshape(-1, 1) - y = np.array([1, 2, 3, 4, 5.0, 6.0]).reshape(-1, 1) - validator = InputValidator() - x_t, y_t = validator.validate(x, y, is_classification=True) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.issubdtype(y_t.dtype, np.number)) - self.assertTrue(np.isnan(x_t).any()) # Preserve NaN in features - self.assertEqual(type_of_target(y_t), 'multiclass') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - - # numpy - categorical - regression - # nan in target should raise error - y = np.random.random_sample((6, 1)) - y[1] = np.nan - with self.assertRaisesRegex(ValueError, 'Target values cannot contain missing/NaN'): - InputValidator().validate_target(y) - - # numpy - numerical - classification - # Numerical numpy features should continue without encoding - # categorical encoding of Nan for the targets is not supported - x = np.random.random_sample((4, 4)) - x[3] = np.nan - y = np.random.choice([0.0, 1.0], 4) - y[1] = np.nan - x_t = InputValidator().validate_features(x) - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - self.assertTrue(np.isnan(x_t).any()) - self.assertEqual(type_of_target(y_t), 'multiclass') - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - - with self.assertRaisesRegex(ValueError, 'Target values cannot contain missing/NaN'): - InputValidator().validate_target(y, is_classification=True) - - with self.assertRaisesRegex(ValueError, 'Target values cannot contain missing/NaN'): - InputValidator().validate_target(y, is_classification=False) - - # Make sure we allow NaN in numerical columns - x_only_numerical = np.random.random_sample(4) - x[3] = np.nan - x_only_numerical = pd.DataFrame(data={'A': x_only_numerical, 'B': x_only_numerical*2}) - try: - InputValidator().validate_features(x_only_numerical) - except ValueError: - self.fail("NaN values in numerical columns is allowed") - - # Make sure we do not allow NaN in categorical columns - x_only_categorical = pd.DataFrame(data=pd.Series([1, 2, pd.NA], dtype="category")) - with self.assertRaisesRegex(ValueError, 'Categorical features in a dataframe cannot'): - InputValidator().validate_features(x_only_categorical) - - y = np.random.choice([0.0, 1.0], 4) - y[1] = np.nan - y = pd.DataFrame(y) - - with self.assertRaisesRegex(ValueError, 'Target values cannot contain missing/NaN'): - InputValidator().validate_target(y, is_classification=True) - - with self.assertRaisesRegex(ValueError, 'Target values cannot contain missing/NaN'): - InputValidator().validate_target(y, is_classification=False) - return - - def test_no_new_category_after_fit(self): - # First make sure no problem if no categorical - x = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [5, 6, 7, 8]}) - y = pd.DataFrame([1, 2, 3, 4]) - validator = InputValidator() - validator.validate(x, y, is_classification=True) - validator.validate_features(x) - x['A'] = x['A'].apply(lambda x: x*x) - 
validator.validate_features(x) - - # Then make sure we catch categorical extra categories - x = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [5, 6, 7, 8]}, dtype='category') - y = pd.DataFrame([1, 2, 3, 4]) - validator = InputValidator() - validator.validate(x, y, is_classification=True) - validator.validate_features(x) - x['A'] = x['A'].apply(lambda x: x*x) - with self.assertRaisesRegex( - ValueError, - 'During fit, the input features contained categorical values' - ): - validator.validate_features(x) - - # For label encoder of targets - with self.assertRaisesRegex( - ValueError, - 'During fit, the target array contained the categorical' - ): - validator.validate_target(pd.DataFrame([1, 2, 5, 4])) - - # For ordinal encoder of targets - x = pd.DataFrame({'A': [1, 2, 3, 4], 'B': [5, 6, 7, 8]}, dtype='category') - validator = InputValidator() - validator.validate(x, x, is_classification=True) - validator.validate_target(pd.DataFrame( - {'A': [1, 2, 3, 4], 'B': [5, 6, 7, 8]}, dtype='category') - ) - with self.assertRaisesRegex( - ValueError, - 'During fit, the target array contained the categorical' - ): - validator.validate_target(pd.DataFrame( - {'A': [1, 2, 3, 4], 'B': [5, 9, 7, 8]}, dtype='category') - ) - return - - def test_big_dataset_encoding(self): - x, y = sklearn.datasets.fetch_openml(data_id=2, return_X_y=True, as_frame=True) - validator = InputValidator() - - with self.assertRaisesRegex( - ValueError, - 'Categorical features in a dataframe cannot contain missing/NaN' - ): - x_t, y_t = validator.validate(x, y, is_classification=True) - - # Make sure translation works apart from Nan +@pytest.mark.parametrize('openmlid', [2, 40975, 40984]) +@pytest.mark.parametrize('as_frame', [True, False]) +def test_data_validation_for_classification(openmlid, as_frame): + x, y = sklearn.datasets.fetch_openml(data_id=openmlid, return_X_y=True, as_frame=as_frame) + validator = InputValidator(is_classification=True) + if as_frame: # NaN is not supported in categories, so - # drop columns with them. 
Also, do a proof of concept - # that all nan column is preserved, so that the pipeline deal - # with it - x = x.dropna('columns', 'any') - x.insert(len(x.columns), 'NaNColumn', np.nan, True) - x_t, y_t = validator.validate(x, y, is_classification=True) - self.assertTupleEqual(np.shape(x), np.shape(x_t)) - - self.assertTrue(np.all(pd.isnull(x_t[:, -1]))) - - # Leave columns that are complete NaN - # The sklearn pipeline will handle that - self.assertTrue(np.isnan(x_t).any()) - np.testing.assert_array_equal( - pd.isnull(x.dropna(axis='columns', how='all')), - pd.isnull(x.dropna(axis='columns', how='any')) - ) - - # make sure everything was encoded to number - self.assertTrue(np.issubdtype(x_t.dtype, np.number)) - - # No change to numerical columns - np.testing.assert_array_equal(x['carbon'].to_numpy(), x_t[:, 3]) - - # Categorical columns are sorted to the beginning - self.assertEqual( - validator.feature_types, - (['categorical'] * 3) + (['numerical'] * 7) - ) - self.assertEqual(x.iloc[0, 6], 610) - np.testing.assert_array_equal(x_t[0], [0, 0, 0, 8, 0, 0, 0.7, 610, 0, np.NaN]) - - return - - def test_join_and_check(self): - validator = InputValidator() - - # Numpy Testing - y = np.array([2, 2, 3, 4, 5]) - y_test = np.array([3, 4, 5, 6, 1]) - - joined = validator.join_and_check(y, y_test) - np.testing.assert_array_equal( - joined, - np.array([2, 2, 3, 4, 5, 3, 4, 5, 6, 1]) - ) - - validator.validate_target(joined, is_classification=True) - y_encoded = validator.validate_target(y) - y_test_encoded = validator.validate_target(y_test) - - # If a common encoding happened, then common elements - # should have a common encoding - self.assertEqual(y_encoded[2], y_test_encoded[0]) - - # Pandas Testing - validator = InputValidator() - joined = validator.join_and_check( - pd.DataFrame(y), - pd.DataFrame(y_test) - ) - np.testing.assert_array_equal( - joined, - pd.DataFrame([2, 2, 3, 4, 5, 3, 4, 5, 6, 1]) - ) - - # List Testing - validator = InputValidator() - joined = validator.join_and_check( - [2, 2, 3, 4, 5], - [3, 4, 5, 6, 1] - ) - np.testing.assert_array_equal( - joined, - [2, 2, 3, 4, 5, 3, 4, 5, 6, 1] - ) - - # Make sure some messages are triggered - y = np.array([[1, 0, 0, 1], [0, 0, 1, 1], [0, 0, 0, 0]]) - y_test = np.array([3, 4, 5, 6, 1]) - with self.assertRaisesRegex( - ValueError, - 'Train and test targets must have the same dimensionality' - ): - joined = validator.join_and_check(y, y_test) - with self.assertRaisesRegex( - ValueError, - 'Train and test targets must be of the same type' - ): - joined = validator.join_and_check(y, pd.DataFrame(y_test)) - - def test_big_dataset_encoding2(self): - """ - Makes sure that when there are multiple classes, - and test/train targets differ, we proactively encode together - the data between test and train - """ - X, y = sklearn.datasets.fetch_openml(data_id=183, return_X_y=True, as_frame=True) - X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - X, - y, - random_state=1 - ) - - # Make sure this test makes sense, so that y_test - # and y_train have different classes - all_classes = set(np.unique(y_test)).union(set(np.unique(y_train))) - elements_in_test_only = np.setdiff1d(np.unique(y_test), np.unique(y_train)) - self.assertGreater(len(elements_in_test_only), 0) - - validator = InputValidator() - common = validator.join_and_check( - pd.DataFrame(y), - pd.DataFrame(y_test) + # drop columns with them. 
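+            # Only columns that are both categorical and contain missing
+            # values are dropped below; numerical columns with NaN are kept,
+            # since the sklearn pipeline can handle those later on.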
+            nan_cols = [i for i in x.columns if x[i].isnull().any()]
+            cat_cols = [i for i in x.columns if x[i].dtype.name in ['category', 'bool']]
+            unsupported_columns = list(set(nan_cols) & set(cat_cols))
+            if len(unsupported_columns) > 0:
+                x.drop(unsupported_columns, axis=1, inplace=True)
+
+    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+        x, y, test_size=0.33, random_state=0)
+
+    validator.fit(X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test)
+
+    X_train_t, y_train_t = validator.transform(X_train, y_train)
+    assert np.shape(X_train) == np.shape(X_train_t)
+
+    # Leave columns that are complete NaN
+    # The sklearn pipeline will handle that
+    if as_frame and np.any(pd.isnull(X_train).values.all(axis=0)):
+        assert np.any(pd.isnull(X_train_t).values.all(axis=0))
+    elif not as_frame and np.any(pd.isnull(X_train).all(axis=0)):
+        assert np.any(pd.isnull(X_train_t).all(axis=0))
+
+    # make sure everything was encoded to number
+    assert np.issubdtype(X_train_t.dtype, np.number)
+    assert np.issubdtype(y_train_t.dtype, np.number)
+
+    # Categorical columns are sorted to the beginning
+    if as_frame:
+        assert validator.feature_validator.feat_type is not None
+        ordered_unique_elements = list(dict.fromkeys(validator.feature_validator.feat_type))
+        if len(ordered_unique_elements) > 1:
+            assert ordered_unique_elements[0] == 'categorical'
+
+
+@pytest.mark.parametrize('openmlid', [505, 546, 531])
+@pytest.mark.parametrize('as_frame', [True, False])
+def test_data_validation_for_regression(openmlid, as_frame):
+    x, y = sklearn.datasets.fetch_openml(data_id=openmlid, return_X_y=True, as_frame=as_frame)
+    validator = InputValidator(is_classification=False)
+
+    if as_frame:
+        # NaN is not supported in categories, so
+        # drop columns with them.
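+            # Same filtering as in the classification test above: drop only
+            # the intersection of NaN-containing and categorical columns.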
+            nan_cols = [i for i in x.columns if x[i].isnull().any()]
+            cat_cols = [i for i in x.columns if x[i].dtype.name in ['category', 'bool']]
+            unsupported_columns = list(set(nan_cols) & set(cat_cols))
+            if len(unsupported_columns) > 0:
+                x.drop(unsupported_columns, axis=1, inplace=True)
+
+    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+        x, y, test_size=0.33, random_state=0)
+
+    validator.fit(X_train=X_train, y_train=y_train)
+
+    X_train_t, y_train_t = validator.transform(X_train, y_train)
+    assert np.shape(X_train) == np.shape(X_train_t)
+
+    # Leave columns that are complete NaN
+    # The sklearn pipeline will handle that
+    if as_frame and np.any(pd.isnull(X_train).values.all(axis=0)):
+        assert np.any(pd.isnull(X_train_t).values.all(axis=0))
+    elif not as_frame and np.any(pd.isnull(X_train).all(axis=0)):
+        assert np.any(pd.isnull(X_train_t).all(axis=0))
+
+    # make sure everything was encoded to number
+    assert np.issubdtype(X_train_t.dtype, np.number)
+    assert np.issubdtype(y_train_t.dtype, np.number)
+
+    # Categorical columns are sorted to the beginning
+    if as_frame:
+        assert validator.feature_validator.feat_type is not None
+        ordered_unique_elements = list(dict.fromkeys(validator.feature_validator.feat_type))
+        if len(ordered_unique_elements) > 1:
+            assert ordered_unique_elements[0] == 'categorical'
+
+
+def test_sparse_data_validation_for_regression():
+    X, y = sklearn.datasets.make_regression(n_samples=100, n_features=50, random_state=0)
+    X_sp = sparse.coo_matrix(X)
+    validator = InputValidator(is_classification=False)
+
+    validator.fit(X_train=X_sp, y_train=y)
+
+    X_t, y_t = validator.transform(X, y)
+    assert np.shape(X) == np.shape(X_t)
+
+    # make sure everything was encoded to number
+    assert np.issubdtype(X_t.dtype, np.number)
+    assert np.issubdtype(y_t.dtype, np.number)
+
+    # Make sure we can change the sparse format
+    X_t, y_t = validator.transform(sparse.csr_matrix(X), y)
+
+
+def test_validation_unsupported():
+    """
+    Makes sure we raise a proper message to the user
+    when providing unsupported data input
+    """
+    validator = InputValidator()
+    with pytest.raises(ValueError, match=r"Inconsistent number of train datapoints.*"):
+        validator.fit(
+            X_train=np.array([[0, 1, 0], [0, 1, 1]]),
+            y_train=np.array([0, 1, 0, 0, 0, 0]),
+        )
+    with pytest.raises(ValueError, match=r"Inconsistent number of test datapoints.*"):
+        validator.fit(
+            X_train=np.array([[0, 1, 0], [0, 1, 1]]),
+            y_train=np.array([0, 1]),
+            X_test=np.array([[0, 1, 0], [0, 1, 1]]),
+            y_test=np.array([0, 1, 0, 0, 0, 0]),
+        )
+    with pytest.raises(ValueError, match=r"Cannot call transform on a validator .*fitted"):
+        validator.transform(
+            X=np.array([[0, 1, 0], [0, 1, 1]]),
+            y=np.array([0, 1]),
        )
-
-        validator.validate_target(common, is_classification=True)
-
-        encoded_classes = validator.target_encoder.classes_
-        missing = all_classes - set(encoded_classes)
-        self.assertEqual(len(missing), 0)
-
-    def test_all_posible_dtype_changes(self):
-        """We do not allow a change in dtype once inputvalidator
-        is fitted"""
-        data = [[1, 0, 1], [1, 1, 1]]
-        type_perms = list(itertools.permutations([
-            data,
-            np.array(data),
-            pd.DataFrame(data)
-        ], r=2))
-
-        for first, second in type_perms:
-            validator = InputValidator()
-            validator.validate_target(first)
-            with self.assertRaisesRegex(ValueError,
-                                        "Auto-sklearn previously received targets of type"):
-                validator.validate_target(second)
-            validator.validate_features(first)
-            with self.assertRaisesRegex(ValueError,
-                                        "Auto-sklearn previously received 
features of type"): - validator.validate_features(second) diff --git a/test/test_util/test_hash.py b/test/test_util/test_hash.py deleted file mode 100644 index ba7cf026d7..0000000000 --- a/test/test_util/test_hash.py +++ /dev/null @@ -1,62 +0,0 @@ -import unittest - -import numpy as np -import scipy.sparse - -from autosklearn.util.hash import hash_array_or_matrix - - -class HashTests(unittest.TestCase): - - def test_c_contiguous_array(self): - array = np.array([[1, 2], [3, 4]]) - - hash = hash_array_or_matrix(array) - - self.assertIsNotNone(hash) - - def test_f_contiguous_array(self): - array = np.array([[1, 2], [3, 4]]) - array = np.asfortranarray(array) - - hash = hash_array_or_matrix(array) - - self.assertIsNotNone(hash) - - def test_transpose_arrays(self): - c_array = np.array([[1, 2], [3, 4]]) - f_array = np.array([[1, 3], [2, 4]]) - f_array = np.asfortranarray(f_array) - - c_hash = hash_array_or_matrix(c_array) - f_hash = hash_array_or_matrix(f_array) - - self.assertEqual(c_hash, f_hash) - - def test_same_data_arrays(self): - first_array = np.array([[1, 2], [3, 4]]) - second_array = np.array([[1, 2], [3, 4]]) - - first_hash = hash_array_or_matrix(first_array) - second_hash = hash_array_or_matrix(second_array) - - self.assertEqual(first_hash, second_hash) - - def test_different_data_arrays(self): - first_array = np.array([[1, 2], [3, 4]]) - second_array = np.array([[1, 3], [2, 4]]) - - first_hash = hash_array_or_matrix(first_array) - second_hash = hash_array_or_matrix(second_array) - - self.assertNotEqual(first_hash, second_hash) - - def test_scipy_csr(self): - row = np.array([0, 0, 1, 2, 2, 2]) - col = np.array([0, 2, 2, 0, 1, 2]) - data = np.array([1, 2, 3, 4, 5, 6]) - matrix = scipy.sparse.csr_matrix((data, (row, col)), shape=(3, 3)) - - hash = hash_array_or_matrix(matrix) - - self.assertIsNotNone(hash) From 1f5bb5a3556d6f12a465e33c5018a1e96cca8390 Mon Sep 17 00:00:00 2001 From: Matthias Feurer Date: Fri, 15 Jan 2021 13:21:05 +0100 Subject: [PATCH 06/10] make two warnings an info --- autosklearn/metalearning/metalearning/kNearestDatasets/kND.py | 3 +-- .../metalearning/optimizers/metalearn_optimizer/metalearner.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/autosklearn/metalearning/metalearning/kNearestDatasets/kND.py b/autosklearn/metalearning/metalearning/kNearestDatasets/kND.py index b0d1a02933..336d3b6bb3 100644 --- a/autosklearn/metalearning/metalearning/kNearestDatasets/kND.py +++ b/autosklearn/metalearning/metalearning/kNearestDatasets/kND.py @@ -139,8 +139,7 @@ def kBestSuggestions(self, x, k=1, exclude_double_configurations=True): dataset_name] if best_configuration is None: - self.logger.warning("Found no best configuration for instance " - "%s" % dataset_name) + self.logger.info("Found no best configuration for instance %s" % dataset_name) continue if exclude_double_configurations: diff --git a/autosklearn/metalearning/optimizers/metalearn_optimizer/metalearner.py b/autosklearn/metalearning/optimizers/metalearn_optimizer/metalearner.py index 31eba70167..6092343a7a 100644 --- a/autosklearn/metalearning/optimizers/metalearn_optimizer/metalearner.py +++ b/autosklearn/metalearning/optimizers/metalearn_optimizer/metalearner.py @@ -110,7 +110,7 @@ def _learn(self, exclude_double_configurations=True): runs[task_id] = self.meta_base.get_runs(task_id) except KeyError: # TODO should I really except this? 
- self.logger.warning("Could not find runs for instance %s" % task_id) + self.logger.info("Could not find runs for instance %s" % task_id) runs[task_id] = pd.Series([], name=task_id) runs = pd.DataFrame(runs) From 05e7263c6fbfa46e30bb0f4d9fe2389366af940b Mon Sep 17 00:00:00 2001 From: Francisco Rivera Valverde <44504424+franchuterivera@users.noreply.github.com> Date: Mon, 18 Jan 2021 12:49:14 +0100 Subject: [PATCH 07/10] Remove the __main__ restriction from examples (#1053) * Allow to pass the context to each block * Re-enable examples without main * Intensification supports just 1 worker * spawn if provided dask client * Mock dask client * Flake fix * distributed wait redundant * Error out on unsupported calls --- .gitignore | 3 + autosklearn/automl.py | 22 +- autosklearn/ensemble_builder.py | 9 +- autosklearn/smbo.py | 4 + autosklearn/util/single_thread_client.py | 86 ++++ examples/20_basic/example_classification.py | 65 ++- .../example_multilabel_classification.py | 117 +++-- examples/20_basic/example_regression.py | 51 +-- .../example_calc_multiple_metrics.py | 70 +-- examples/40_advanced/example_feature_types.py | 75 ++- .../example_get_pipeline_components.py | 335 +++++++------- examples/40_advanced/example_metrics.py | 298 ++++++------ .../40_advanced/example_pandas_train_test.py | 187 ++++---- examples/40_advanced/example_resampling.py | 221 +++++---- examples/60_search/example_random_search.py | 201 ++++---- examples/60_search/example_sequential.py | 85 ++-- .../60_search/example_successive_halving.py | 433 +++++++++--------- .../example_extending_classification.py | 59 ++- .../example_extending_preprocessor.py | 61 ++- .../example_extending_regression.py | 57 ++- ...mple_restrict_number_of_hyperparameters.py | 223 +++++---- test/test_util/test_single_thread_client.py | 27 ++ 22 files changed, 1411 insertions(+), 1278 deletions(-) create mode 100644 autosklearn/util/single_thread_client.py create mode 100644 test/test_util/test_single_thread_client.py diff --git a/.gitignore b/.gitignore index 58595df4f7..92fa37b152 100755 --- a/.gitignore +++ b/.gitignore @@ -75,3 +75,6 @@ coverage.xml *,cover .hypothesis/ prof/ + +# Mypy +.mypy_cache/ diff --git a/autosklearn/automl.py b/autosklearn/automl.py index d9119beb6e..0dbaa78b56 100644 --- a/autosklearn/automl.py +++ b/autosklearn/automl.py @@ -64,6 +64,7 @@ CoalescenseChoice ) from autosklearn.pipeline.components.data_preprocessing.rescaling import RescalingChoice +from autosklearn.util.single_thread_client import SingleThreadedClient def _model_predict(model, X, batch_size, logger, task): @@ -222,6 +223,16 @@ def __init__(self, # The ensemble performance history through time self.ensemble_performance_history = [] + # Single core, local runs should use fork + # to prevent the __main__ requirements in + # examples. 
Nevertheless, multi-process runs + # have spawn as requirement to reduce the + # possibility of a deadlock + self._multiprocessing_context = 'spawn' + if self._n_jobs == 1 and self._dask_client is None: + self._multiprocessing_context = 'fork' + self._dask_client = SingleThreadedClient() + if not isinstance(self._time_for_task, int): raise ValueError("time_left_for_this_task not of type integer, " "but %s" % str(type(self._time_for_task))) @@ -241,7 +252,7 @@ def _create_dask_client(self): self._dask_client = dask.distributed.Client( dask.distributed.LocalCluster( n_workers=self._n_jobs, - processes=True, + processes=True if self._n_jobs != 1 else False, threads_per_worker=1, # We use the temporal directory to save the # dask workers, because deleting workers @@ -288,7 +299,8 @@ def _get_logger(self, name): # under the above logging configuration setting # We need to specify the logger_name so that received records # are treated under the logger_name ROOT logger setting - context = multiprocessing.get_context('spawn') + context = multiprocessing.get_context( + self._multiprocessing_context) self.stop_logging_server = context.Event() port = context.Value('l') # be safe by using a long port.value = -1 @@ -389,6 +401,7 @@ def _do_dummy_prediction(self, datamanager, num_run): abort_on_first_run_crash=False, cost_for_crash=get_cost_of_crash(self._metric), port=self._logger_port, + pynisher_context=self._multiprocessing_context, **self._resampling_strategy_arguments) status, cost, runtime, additional_info = ta.run(num_run, cutoff=self._time_for_task) @@ -558,6 +571,7 @@ def fit( self._logger.debug(' resampling_strategy_arguments: %s', str(self._resampling_strategy_arguments)) self._logger.debug(' n_jobs: %s', str(self._n_jobs)) + self._logger.debug(' multiprocessing_context: %s', str(self._multiprocessing_context)) self._logger.debug(' dask_client: %s', str(self._dask_client)) self._logger.debug(' precision: %s', str(self.precision)) self._logger.debug(' disable_evaluator_output: %s', str(self._disable_evaluator_output)) @@ -667,6 +681,7 @@ def fit( ensemble_memory_limit=self._memory_limit, random_state=self._seed, logger_port=self._logger_port, + pynisher_context=self._multiprocessing_context, ) self._stopwatch.stop_task(ensemble_task_name) @@ -742,6 +757,7 @@ def fit( smac_scenario_args=self._smac_scenario_args, scoring_functions=self._scoring_functions, port=self._logger_port, + pynisher_context=self._multiprocessing_context, ensemble_callback=proc_ensemble, ) @@ -1029,10 +1045,10 @@ def fit_ensemble(self, y, task=None, precision=32, ensemble_memory_limit=self._memory_limit, random_state=self._seed, logger_port=self._logger_port, + pynisher_context=self._multiprocessing_context, ) manager.build_ensemble(self._dask_client) future = manager.futures.pop() - dask.distributed.wait([future]) # wait for the ensemble process to finish result = future.result() if result is None: raise ValueError("Error building the ensemble - please check the log file and command " diff --git a/autosklearn/ensemble_builder.py b/autosklearn/ensemble_builder.py index 2a8cc42a3c..3b1b2d8241 100644 --- a/autosklearn/ensemble_builder.py +++ b/autosklearn/ensemble_builder.py @@ -58,6 +58,7 @@ def __init__( ensemble_memory_limit: Optional[int], random_state: int, logger_port: int = logging.handlers.DEFAULT_TCP_LOGGING_PORT, + pynisher_context: str = 'fork', ): """ SMAC callback to handle ensemble building @@ -105,6 +106,8 @@ def __init__( read at most n new prediction files in each iteration logger_port: int port that 
receives logging records + pynisher_context: str + The multiprocessing context for pynisher. One of spawn/fork/forkserver. Returns ------- @@ -128,6 +131,7 @@ def __init__( self.ensemble_memory_limit = ensemble_memory_limit self.random_state = random_state self.logger_port = logger_port + self.pynisher_context = pynisher_context # Store something similar to SMAC's runhistory self.history = [] @@ -155,7 +159,6 @@ def __call__( def build_ensemble( self, dask_client: dask.distributed.Client, - pynisher_context: str = 'spawn', unit_test: bool = False ) -> None: @@ -229,7 +232,7 @@ def build_ensemble( iteration=self.iteration, return_predictions=False, priority=100, - pynisher_context=pynisher_context, + pynisher_context=self.pynisher_context, logger_port=self.logger_port, unit_test=unit_test, )) @@ -573,7 +576,7 @@ def run( end_at: Optional[float] = None, time_buffer=5, return_predictions: bool = False, - pynisher_context: str = 'spawn', # only change for unit testing! + pynisher_context: str = 'spawn', ): if time_left is None and end_at is None: diff --git a/autosklearn/smbo.py b/autosklearn/smbo.py index fd82a95a32..5ab303b4fd 100644 --- a/autosklearn/smbo.py +++ b/autosklearn/smbo.py @@ -224,6 +224,7 @@ def __init__(self, config_space, dataset_name, smac_scenario_args=None, get_smac_object_callback=None, scoring_functions=None, + pynisher_context='spawn', ensemble_callback: typing.Optional[EnsembleBuilderManager] = None, ): super(AutoMLSMBO, self).__init__() @@ -269,6 +270,8 @@ def __init__(self, config_space, dataset_name, self.get_smac_object_callback = get_smac_object_callback self.scoring_functions = scoring_functions + self.pynisher_context = pynisher_context + self.ensemble_callback = ensemble_callback dataset_name_ = "" if dataset_name is None else dataset_name @@ -448,6 +451,7 @@ def run_smbo(self): disable_file_output=self.disable_file_output, scoring_functions=self.scoring_functions, port=self.port, + pynisher_context=self.pynisher_context, **self.resampling_strategy_args ) ta = ExecuteTaFuncWithQueue diff --git a/autosklearn/util/single_thread_client.py b/autosklearn/util/single_thread_client.py new file mode 100644 index 0000000000..5cd7c653f4 --- /dev/null +++ b/autosklearn/util/single_thread_client.py @@ -0,0 +1,86 @@ +import typing +from pathlib import Path + +import dask.distributed + + +class DummyFuture(dask.distributed.Future): + """ + A class that mimics a distributed Future, the outcome of + performing submit on a distributed client. + """ + def __init__(self, result: typing.Any) -> None: + self._result = result # type: typing.Any + + def result(self, timeout: typing.Optional[int] = None) -> typing.Any: + return self._result + + def cancel(self) -> None: + pass + + def done(self) -> bool: + return True + + def __repr__(self) -> str: + return "DummyFuture: {}".format(self._result) + + def __del__(self) -> None: + pass + + +class SingleThreadedClient(dask.distributed.Client): + """ + A class to Mock the Distributed Client class, in case + Auto-Sklearn is meant to run in the current Thread. 
+ """ + def __init__(self) -> None: + + # Raise a not implemented error if using a method from Client + implemented_methods = ['submit', 'close', 'shutdown', 'write_scheduler_file', + '_get_scheduler_info', 'nthreads'] + method_list = [func for func in dir(dask.distributed.Client) if callable( + getattr(dask.distributed.Client, func)) and not func.startswith('__')] + for method in method_list: + if method in implemented_methods: + continue + setattr(self, method, self._unsupported_method) + pass + + def _unsupported_method(self) -> None: + raise NotImplementedError() + + def submit( + self, + func: typing.Callable, + *args: typing.List, + priority: int = 0, + **kwargs: typing.Dict, + ) -> typing.Any: + return DummyFuture(func(*args, **kwargs)) + + def close(self) -> None: + pass + + def shutdown(self) -> None: + pass + + def write_scheduler_file(self, scheduler_file: str) -> None: + Path(scheduler_file).touch() + return + + def _get_scheduler_info(self) -> typing.Dict: + return { + 'workers': ['127.0.0.1'], + 'type': 'Scheduler', + } + + def nthreads(self) -> typing.Dict: + return { + '127.0.0.1': 1, + } + + def __repr__(self) -> str: + return 'SingleThreadedClient()' + + def __del__(self) -> None: + pass diff --git a/examples/20_basic/example_classification.py b/examples/20_basic/example_classification.py index d7949d1065..9ed8d5bf7b 100644 --- a/examples/20_basic/example_classification.py +++ b/examples/20_basic/example_classification.py @@ -13,36 +13,35 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ============ - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) - - ############################################################################ - # Build and fit a regressor - # ========================= - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=120, - per_run_time_limit=30, - tmp_folder='/tmp/autosklearn_classification_example_tmp', - output_folder='/tmp/autosklearn_classification_example_out', - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - ############################################################################ - # Print the final ensemble constructed by auto-sklearn - # ==================================================== - - print(automl.show_models()) - - ########################################################################### - # Get the Score of the final ensemble - # =================================== - - predictions = automl.predict(X_test) - print("Accuracy score:", sklearn.metrics.accuracy_score(y_test, predictions)) +############################################################################ +# Data Loading +# ============ + +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) + +############################################################################ +# Build and fit a regressor +# ========================= + +automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=120, + per_run_time_limit=30, + tmp_folder='/tmp/autosklearn_classification_example_tmp', + output_folder='/tmp/autosklearn_classification_example_out', +) +automl.fit(X_train, y_train, dataset_name='breast_cancer') + 
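
A short aside on the SingleThreadedClient added above: submit() runs the function inline and wraps its return value in a DummyFuture, so callers written against the dask client API work unchanged. A minimal sketch of that contract, using only methods shown in the diff:

from autosklearn.util.single_thread_client import SingleThreadedClient

client = SingleThreadedClient()
future = client.submit(lambda x: x + 1, 41)  # executed immediately, in this thread
assert future.done()
assert future.result() == 42
client.close()  # a no-op, mirroring the real dask client interface
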
+############################################################################ +# Print the final ensemble constructed by auto-sklearn +# ==================================================== + +print(automl.show_models()) + +########################################################################### +# Get the Score of the final ensemble +# =================================== + +predictions = automl.predict(X_test) +print("Accuracy score:", sklearn.metrics.accuracy_score(y_test, predictions)) diff --git a/examples/20_basic/example_multilabel_classification.py b/examples/20_basic/example_multilabel_classification.py index 00b1a6bae2..30f9be498b 100644 --- a/examples/20_basic/example_multilabel_classification.py +++ b/examples/20_basic/example_multilabel_classification.py @@ -16,62 +16,61 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ============ - - # Using reuters multilabel dataset -- https://www.openml.org/d/40594 - X, y = sklearn.datasets.fetch_openml(data_id=40594, return_X_y=True, as_frame=False) - - # fetch openml downloads a numpy array with TRUE/FALSE strings. Re-map it to - # integer dtype with ones and zeros - # This is to comply with Scikit-learn requirement: - # "Positive classes are indicated with 1 and negative classes with 0 or -1." - # More information on: https://scikit-learn.org/stable/modules/multiclass.html - y[y == 'TRUE'] = 1 - y[y == 'FALSE'] = 0 - y = y.astype(np.int) - - # Using type of target is a good way to make sure your data - # is properly formatted - print(f"type_of_target={type_of_target(y)}") - - X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( - X, y, random_state=1 - ) - - ############################################################################ - # Building the classifier - # ======================= - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=30, - # Bellow two flags are provided to speed up calculations - # Not recommended for a real implementation - initial_configurations_via_metalearning=0, - smac_scenario_args={'runcount_limit': 1}, - ) - automl.fit(X_train, y_train, dataset_name='reuters') - - ############################################################################ - # Print the final ensemble constructed by auto-sklearn - # ==================================================== - - print(automl.show_models()) - - ############################################################################ - # Print statistics about the auto-sklearn run - # =========================================== - - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) - - ############################################################################ - # Get the Score of the final ensemble - # =================================== - - predictions = automl.predict(X_test) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) +############################################################################ +# Data Loading +# ============ + +# Using reuters multilabel dataset -- https://www.openml.org/d/40594 +X, y = sklearn.datasets.fetch_openml(data_id=40594, return_X_y=True, as_frame=False) + +# fetch openml downloads a numpy array with TRUE/FALSE strings. 
Re-map it to
+# integer dtype with ones and zeros
+# This is to comply with Scikit-learn requirement:
+# "Positive classes are indicated with 1 and negative classes with 0 or -1."
+# More information on: https://scikit-learn.org/stable/modules/multiclass.html
+y[y == 'TRUE'] = 1
+y[y == 'FALSE'] = 0
+y = y.astype(np.int)
+
+# Using type of target is a good way to make sure your data
+# is properly formatted
+print(f"type_of_target={type_of_target(y)}")
+
+X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
+    X, y, random_state=1
+)
+
+############################################################################
+# Building the classifier
+# =======================
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=60,
+    per_run_time_limit=30,
+    # Below two flags are provided to speed up calculations
+    # Not recommended for a real implementation
+    initial_configurations_via_metalearning=0,
+    smac_scenario_args={'runcount_limit': 1},
+)
+automl.fit(X_train, y_train, dataset_name='reuters')
+
+############################################################################
+# Print the final ensemble constructed by auto-sklearn
+# ====================================================
+
+print(automl.show_models())
+
+############################################################################
+# Print statistics about the auto-sklearn run
+# ===========================================
+
+# Print statistics about the auto-sklearn run such as number of
+# iterations, number of models failed with a time out.
+print(automl.sprint_statistics())
+
+############################################################################
+# Get the Score of the final ensemble
+# ===================================
+
+predictions = automl.predict(X_test)
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
diff --git a/examples/20_basic/example_regression.py b/examples/20_basic/example_regression.py
index 9cf83f5e3c..5c879006c7 100644
--- a/examples/20_basic/example_regression.py
+++ b/examples/20_basic/example_regression.py
@@ -13,37 +13,36 @@
 import autosklearn.regression
 
 
-if __name__ == "__main__":
-    ############################################################################
-    # Data Loading
-    # ============
+############################################################################
+# Data Loading
+# ============
 
-    X, y = sklearn.datasets.load_boston(return_X_y=True)
+X, y = sklearn.datasets.load_boston(return_X_y=True)
 
-    X_train, X_test, y_train, y_test = \
-        sklearn.model_selection.train_test_split(X, y, random_state=1)
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1)
 
-    ############################################################################
-    # Build and fit a regressor
-    # =========================
+############################################################################
+# Build and fit a regressor
+# =========================
 
-    automl = autosklearn.regression.AutoSklearnRegressor(
-        time_left_for_this_task=120,
-        per_run_time_limit=30,
-        tmp_folder='/tmp/autosklearn_regression_example_tmp',
-        output_folder='/tmp/autosklearn_regression_example_out',
-    )
-    automl.fit(X_train, y_train, dataset_name='boston')
+automl = autosklearn.regression.AutoSklearnRegressor(
+    time_left_for_this_task=120,
+    per_run_time_limit=30,
+    tmp_folder='/tmp/autosklearn_regression_example_tmp',
+    output_folder='/tmp/autosklearn_regression_example_out',
+)
+automl.fit(X_train, y_train, 
dataset_name='boston') - ############################################################################ - # Print the final ensemble constructed by auto-sklearn - # ==================================================== +############################################################################ +# Print the final ensemble constructed by auto-sklearn +# ==================================================== - print(automl.show_models()) +print(automl.show_models()) - ########################################################################### - # Get the Score of the final ensemble - # =================================== +########################################################################### +# Get the Score of the final ensemble +# =================================== - predictions = automl.predict(X_test) - print("R2 score:", sklearn.metrics.r2_score(y_test, predictions)) +predictions = automl.predict(X_test) +print("R2 score:", sklearn.metrics.r2_score(y_test, predictions)) diff --git a/examples/40_advanced/example_calc_multiple_metrics.py b/examples/40_advanced/example_calc_multiple_metrics.py index 7139d68832..c7a4e78503 100644 --- a/examples/40_advanced/example_calc_multiple_metrics.py +++ b/examples/40_advanced/example_calc_multiple_metrics.py @@ -11,13 +11,18 @@ """ import autosklearn.classification -import custom_metrics +import numpy as np import pandas as pd import sklearn.datasets import sklearn.metrics from autosklearn.metrics import balanced_accuracy, precision, recall, f1 +def error(solution, prediction): + # custom function defining error + return np.mean(solution != prediction) + + def get_metric_result(cv_results): results = pd.DataFrame.from_dict(cv_results) results = results[results['status'] == "Success"] @@ -26,41 +31,40 @@ def get_metric_result(cv_results): return results[cols] -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ============ +############################################################################ +# Data Loading +# ============ - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) - ############################################################################ - # Build and fit a classifier - # ========================== +############################################################################ +# Build and fit a classifier +# ========================== - error_rate = autosklearn.metrics.make_scorer( - name='custom_error', - score_func=custom_metrics.error, - optimum=0, - greater_is_better=False, - needs_proba=False, - needs_threshold=False - ) - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=120, - per_run_time_limit=30, - scoring_functions=[balanced_accuracy, precision, recall, f1, error_rate] - ) - cls.fit(X_train, y_train, X_test, y_test) +error_rate = autosklearn.metrics.make_scorer( + name='custom_error', + score_func=error, + optimum=0, + greater_is_better=False, + needs_proba=False, + needs_threshold=False +) +cls = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=120, + per_run_time_limit=30, + scoring_functions=[balanced_accuracy, precision, recall, f1, error_rate] +) +cls.fit(X_train, y_train, X_test, y_test) - 
########################################################################### - # Get the Score of the final ensemble - # =================================== +########################################################################### +# Get the Score of the final ensemble +# =================================== - predictions = cls.predict(X_test) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) +predictions = cls.predict(X_test) +print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - print("#" * 80) - print("Metric results") - print(get_metric_result(cls.cv_results_).to_string(index=False)) +print("#" * 80) +print("Metric results") +print(get_metric_result(cls.cv_results_).to_string(index=False)) diff --git a/examples/40_advanced/example_feature_types.py b/examples/40_advanced/example_feature_types.py index f36a6095fe..5f1c83ddcf 100644 --- a/examples/40_advanced/example_feature_types.py +++ b/examples/40_advanced/example_feature_types.py @@ -21,41 +21,40 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ============ - # Load Australian dataset from https://www.openml.org/d/40981 - bunch = data = sklearn.datasets.fetch_openml(data_id=40981, as_frame=True) - y = bunch['target'].to_numpy() - X = bunch['data'].to_numpy(np.float) - - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) - - # Auto-sklearn can automatically recognize categorical/numerical data from a pandas - # DataFrame. This example highlights how the user can provide the feature types, - # when using numpy arrays, as there is no per-column dtype in this case. - # feat_type is a list that tags each column from a DataFrame/ numpy array / list - # with the case-insensitive string categorical or numerical, accordingly. - feat_type = ['Categorical' if x.name == 'category' else 'Numerical' for x in bunch['data'].dtypes] - - ############################################################################ - # Build and fit a classifier - # ========================== - - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=30, - # Bellow two flags are provided to speed up calculations - # Not recommended for a real implementation - initial_configurations_via_metalearning=0, - smac_scenario_args={'runcount_limit': 1}, - ) - cls.fit(X_train, y_train, X_test, y_test, feat_type=feat_type) - - ########################################################################### - # Get the Score of the final ensemble - # =================================== - - predictions = cls.predict(X_test) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) +############################################################################ +# Data Loading +# ============ +# Load Australian dataset from https://www.openml.org/d/40981 +bunch = data = sklearn.datasets.fetch_openml(data_id=40981, as_frame=True) +y = bunch['target'].to_numpy() +X = bunch['data'].to_numpy(np.float) + +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) + +# Auto-sklearn can automatically recognize categorical/numerical data from a pandas +# DataFrame. This example highlights how the user can provide the feature types, +# when using numpy arrays, as there is no per-column dtype in this case. 
diff --git a/examples/40_advanced/example_feature_types.py b/examples/40_advanced/example_feature_types.py
index f36a6095fe..5f1c83ddcf 100644
--- a/examples/40_advanced/example_feature_types.py
+++ b/examples/40_advanced/example_feature_types.py
@@ -21,41 +21,40 @@
 
 import autosklearn.classification
 
-if __name__ == "__main__":
-    ############################################################################
-    # Data Loading
-    # ============
-    # Load Australian dataset from https://www.openml.org/d/40981
-    bunch = data = sklearn.datasets.fetch_openml(data_id=40981, as_frame=True)
-    y = bunch['target'].to_numpy()
-    X = bunch['data'].to_numpy(np.float)
-
-    X_train, X_test, y_train, y_test = \
-        sklearn.model_selection.train_test_split(X, y, random_state=1)
-
-    # Auto-sklearn can automatically recognize categorical/numerical data from a pandas
-    # DataFrame. This example highlights how the user can provide the feature types,
-    # when using numpy arrays, as there is no per-column dtype in this case.
-    # feat_type is a list that tags each column from a DataFrame/ numpy array / list
-    # with the case-insensitive string categorical or numerical, accordingly.
-    feat_type = ['Categorical' if x.name == 'category' else 'Numerical' for x in bunch['data'].dtypes]
-
-    ############################################################################
-    # Build and fit a classifier
-    # ==========================
-
-    cls = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=30,
-        # Bellow two flags are provided to speed up calculations
-        # Not recommended for a real implementation
-        initial_configurations_via_metalearning=0,
-        smac_scenario_args={'runcount_limit': 1},
-    )
-    cls.fit(X_train, y_train, X_test, y_test, feat_type=feat_type)
-
-    ###########################################################################
-    # Get the Score of the final ensemble
-    # ===================================
-
-    predictions = cls.predict(X_test)
-    print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+############################################################################
+# Data Loading
+# ============
+# Load Australian dataset from https://www.openml.org/d/40981
+bunch = sklearn.datasets.fetch_openml(data_id=40981, as_frame=True)
+y = bunch['target'].to_numpy()
+X = bunch['data'].to_numpy(np.float64)
+
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1)
+
+# Auto-sklearn can automatically recognize categorical/numerical data from a pandas
+# DataFrame. This example highlights how the user can provide the feature types,
+# when using numpy arrays, as there is no per-column dtype in this case.
+# feat_type is a list that tags each column from a DataFrame / numpy array / list
+# with the case-insensitive string categorical or numerical, accordingly.
+feat_type = ['Categorical' if x.name == 'category' else 'Numerical' for x in bunch['data'].dtypes]
+
+############################################################################
+# Build and fit a classifier
+# ==========================
+
+cls = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=30,
+    # Below two flags are provided to speed up calculations
+    # Not recommended for a real implementation
+    initial_configurations_via_metalearning=0,
+    smac_scenario_args={'runcount_limit': 1},
+)
+cls.fit(X_train, y_train, X_test, y_test, feat_type=feat_type)
+
+###########################################################################
+# Get the Score of the final ensemble
+# ===================================
+
+predictions = cls.predict(X_test)
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
diff --git a/examples/40_advanced/example_get_pipeline_components.py b/examples/40_advanced/example_get_pipeline_components.py
index 242d4a6a7d..0577e0beba 100644
--- a/examples/40_advanced/example_get_pipeline_components.py
+++ b/examples/40_advanced/example_get_pipeline_components.py
@@ -20,172 +20,171 @@
 
 import autosklearn.classification
 
-if __name__ == "__main__":
-    ############################################################################
-    # Data Loading
-    # ============
-
-    X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
-    X_train, X_test, y_train, y_test = \
-        sklearn.model_selection.train_test_split(X, y, random_state=1)
-
-    ############################################################################
-    # Build and fit the classifier
-    # ============================
-
-    automl = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=30,
-        per_run_time_limit=10,
-        disable_evaluator_output=False,
-        # To simplify querying the models in the final ensemble, we
-        # restrict auto-sklearn to use only pca as a preprocessor
-        include_preprocessors=['pca'],
-    )
-    automl.fit(X_train, y_train, dataset_name='breast_cancer')
-
-    ############################################################################
-    # Predict using the model
-    # =======================
-
-    predictions = automl.predict(X_test)
-    print("Accuracy score:{}".format(
-        sklearn.metrics.accuracy_score(y_test, predictions))
-    )
-
-
-    ############################################################################
-    # Report the models found by Auto-Sklearn
-    # =======================================
-    #
-    # Auto-sklearn uses
-    # `Ensemble Selection `_
-    # to construct ensembles in a post-hoc fashion. The ensemble is a linear
-    # weighting of all models constructed during the hyperparameter optimization.
-    # This prints the final ensemble. It is a list of tuples, each tuple being
-    # the model weight in the ensemble and the model itself.
-
-    print(automl.show_models())
-
-    ###########################################################################
-    # Report statistics about the search
-    # ==================================
-    #
-    # Print statistics about the auto-sklearn run such as number of
-    # iterations, number of models failed with a time out etc.
- print(automl.sprint_statistics()) - - ############################################################################ - # Detailed statistics about the search - part 1 - # ============================================= - # - # Auto-sklearn also keeps detailed statistics of the hyperparameter - # optimization procedurce, which are stored in a so-called - # `run history `_. - - print(automl.automl_.runhistory_) - - ############################################################################ - # Runs are stored inside an ``OrderedDict`` called ``data``: - - print(len(automl.automl_.runhistory_.data)) - - ############################################################################ - # Let's iterative over all entries - - for run_key in automl.automl_.runhistory_.data: - print('#########') - print(run_key) - print(automl.automl_.runhistory_.data[run_key]) - - ############################################################################ - # and have a detailed look at one entry: - - run_key = list(automl.automl_.runhistory_.data.keys())[0] - run_value = automl.automl_.runhistory_.data[run_key] - - ############################################################################ - # The ``run_key`` contains all information describing a run: - - print("Configuration ID:", run_key.config_id) - print("Instance:", run_key.instance_id) - print("Seed:", run_key.seed) - print("Budget:", run_key.budget) - - ############################################################################ - # and the configuration can be looked up in the run history as well: - - print(automl.automl_.runhistory_.ids_config[run_key.config_id]) - - ############################################################################ - # The only other important entry is the budget in case you are using - # auto-sklearn with - # `successive halving <../60_search/example_successive_halving.html>`_. - # The remaining parts of the key can be ignored for auto-sklearn and are - # only there because the underlying optimizer, SMAC, can handle more general - # problems, too. - - ############################################################################ - # The ``run_value`` contains all output from running the configuration: - - print("Cost:", run_value.cost) - print("Time:", run_value.time) - print("Status:", run_value.status) - print("Additional information:", run_value.additional_info) - print("Start time:", run_value.starttime) - print("End time", run_value.endtime) - - ############################################################################ - # Cost is basically the same as a loss. In case the metric to optimize for - # should be maximized, it is internally transformed into a minimization - # metric. Additionally, the status type gives information on whether the run - # was successful, while the additional information's most interesting entry - # is the internal training loss. Furthermore, there is detailed information - # on the runtime available. - - ############################################################################ - # As an example, let's find the best configuration evaluated. 
As - # Auto-sklearn solves a minimization problem internally, we need to look - # for the entry with the lowest loss: - - losses_and_configurations = [ - (run_value.cost, run_key.config_id) - for run_key, run_value in automl.automl_.runhistory_.data.items() - ] - losses_and_configurations.sort() - print("Lowest loss:", losses_and_configurations[0][0]) - print( - "Best configuration:", - automl.automl_.runhistory_.ids_config[losses_and_configurations[0][1]] - ) - - ############################################################################ - # Detailed statistics about the search - part 2 - # ============================================= - # - # To maintain compatibility with scikit-learn, Auto-sklearn gives the - # same data as - # `cv_results_ `_. - - print(automl.cv_results_) - - ############################################################################ - # Inspect the components of the best model - # ======================================== - # - # Iterate over the components of the model and print - # The explained variance ratio per stage - for i, (weight, pipeline) in enumerate(automl.get_models_with_weights()): - for stage_name, component in pipeline.named_steps.items(): - if 'preprocessor' in stage_name: - print( - "The {}th pipeline has a explained variance of {}".format( - i, - # The component is an instance of AutoSklearnChoice. - # Access the sklearn object via the choice attribute - # We want the explained variance attributed of - # each principal component - component.choice.preprocessor.explained_variance_ratio_ - ) +############################################################################ +# Data Loading +# ============ + +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) + +############################################################################ +# Build and fit the classifier +# ============================ + +automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=30, + per_run_time_limit=10, + disable_evaluator_output=False, + # To simplify querying the models in the final ensemble, we + # restrict auto-sklearn to use only pca as a preprocessor + include_preprocessors=['pca'], +) +automl.fit(X_train, y_train, dataset_name='breast_cancer') + +############################################################################ +# Predict using the model +# ======================= + +predictions = automl.predict(X_test) +print("Accuracy score:{}".format( + sklearn.metrics.accuracy_score(y_test, predictions)) +) + + +############################################################################ +# Report the models found by Auto-Sklearn +# ======================================= +# +# Auto-sklearn uses +# `Ensemble Selection `_ +# to construct ensembles in a post-hoc fashion. The ensemble is a linear +# weighting of all models constructed during the hyperparameter optimization. +# This prints the final ensemble. It is a list of tuples, each tuple being +# the model weight in the ensemble and the model itself. + +print(automl.show_models()) + +########################################################################### +# Report statistics about the search +# ================================== +# +# Print statistics about the auto-sklearn run such as number of +# iterations, number of models failed with a time out etc. 
+print(automl.sprint_statistics())
+
+############################################################################
+# Detailed statistics about the search - part 1
+# =============================================
+#
+# Auto-sklearn also keeps detailed statistics of the hyperparameter
+# optimization procedure, which are stored in a so-called
+# `run history `_.
+
+print(automl.automl_.runhistory_)
+
+############################################################################
+# Runs are stored inside an ``OrderedDict`` called ``data``:
+
+print(len(automl.automl_.runhistory_.data))
+
+############################################################################
+# Let's iterate over all entries
+
+for run_key in automl.automl_.runhistory_.data:
+    print('#########')
+    print(run_key)
+    print(automl.automl_.runhistory_.data[run_key])
+
+############################################################################
+# and have a detailed look at one entry:
+
+run_key = list(automl.automl_.runhistory_.data.keys())[0]
+run_value = automl.automl_.runhistory_.data[run_key]
+
+############################################################################
+# The ``run_key`` contains all information describing a run:
+
+print("Configuration ID:", run_key.config_id)
+print("Instance:", run_key.instance_id)
+print("Seed:", run_key.seed)
+print("Budget:", run_key.budget)
+
+############################################################################
+# and the configuration can be looked up in the run history as well:
+
+print(automl.automl_.runhistory_.ids_config[run_key.config_id])
+
+############################################################################
+# The only other important entry is the budget in case you are using
+# auto-sklearn with
+# `successive halving <../60_search/example_successive_halving.html>`_.
+# The remaining parts of the key can be ignored for auto-sklearn and are
+# only there because the underlying optimizer, SMAC, can handle more general
+# problems, too.
+
+############################################################################
+# The ``run_value`` contains all output from running the configuration:
+
+print("Cost:", run_value.cost)
+print("Time:", run_value.time)
+print("Status:", run_value.status)
+print("Additional information:", run_value.additional_info)
+print("Start time:", run_value.starttime)
+print("End time:", run_value.endtime)
+
+############################################################################
+# Cost is basically the same as a loss. If the metric to optimize should be
+# maximized, it is internally transformed into a minimization metric.
+# Additionally, the status type gives information on whether the run was
+# successful, while the additional information's most interesting entry
+# is the internal training loss. Furthermore, there is detailed information
+# on the runtime available.
+
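+############################################################################
+# The following lines are an editorial sketch, not part of the original
+# example. They illustrate the cost/score relationship stated above:
+# auto-sklearn minimizes ``optimum - score``, so for a metric whose optimum
+# is 1 (such as the default accuracy metric assumed here), the validation
+# score of this run can be recovered, roughly, by inverting that transform.
+
+print("Approximate validation score of this run:", 1 - run_value.cost)
+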
+############################################################################
+# As an example, let's find the best configuration evaluated. As
+# Auto-sklearn solves a minimization problem internally, we need to look
+# for the entry with the lowest loss:
+
+losses_and_configurations = [
+    (run_value.cost, run_key.config_id)
+    for run_key, run_value in automl.automl_.runhistory_.data.items()
+]
+losses_and_configurations.sort()
+print("Lowest loss:", losses_and_configurations[0][0])
+print(
+    "Best configuration:",
+    automl.automl_.runhistory_.ids_config[losses_and_configurations[0][1]]
+)
+
+############################################################################
+# Detailed statistics about the search - part 2
+# =============================================
+#
+# To maintain compatibility with scikit-learn, Auto-sklearn gives the
+# same data as
+# `cv_results_ `_.
+
+print(automl.cv_results_)
+
+############################################################################
+# Inspect the components of the best model
+# ========================================
+#
+# Iterate over the components of the model and print
+# the explained variance ratio per stage
+for i, (weight, pipeline) in enumerate(automl.get_models_with_weights()):
+    for stage_name, component in pipeline.named_steps.items():
+        if 'preprocessor' in stage_name:
+            print(
+                "The {}th pipeline has an explained variance of {}".format(
+                    i,
+                    # The component is an instance of AutoSklearnChoice.
+                    # Access the sklearn object via the choice attribute.
+                    # We want the explained variance attribute of
+                    # each principal component
+                    component.choice.preprocessor.explained_variance_ratio_
+                )
+            )
diff --git a/examples/40_advanced/example_metrics.py b/examples/40_advanced/example_metrics.py
index f58ac0f73a..4fbc56f61a 100644
--- a/examples/40_advanced/example_metrics.py
+++ b/examples/40_advanced/example_metrics.py
@@ -22,161 +22,171 @@
 
 ############################################################################
-# Data Loading
-# ============
-# The custom metrics must be in a separate module to be usable together with
-# Auto-sklearn. We also print the content of the module below with
-# ``inspect`` to keep the example self-contained.
+# Custom Metrics +# ============== +def accuracy(solution, prediction): + # custom function defining accuracy + return np.mean(solution == prediction) -import custom_metrics +def error(solution, prediction): + # custom function defining error + return np.mean(solution != prediction) -if __name__ == "__main__": - import inspect - print(inspect.getsource(custom_metrics)) +def accuracy_wk(solution, prediction, dummy): + # custom function defining accuracy and accepting an additional argument + assert dummy is None + return np.mean(solution == prediction) - ############################################################################ - # Data Loading - # ============ - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) +def error_wk(solution, prediction, dummy): + # custom function defining error and accepting an additional argument + assert dummy is None + return np.mean(solution != prediction) - ############################################################################ - # Print a list of available metrics - # ================================= - print("Available CLASSIFICATION metrics autosklearn.metrics.*:") - print("\t*" + "\n\t*".join(autosklearn.metrics.CLASSIFICATION_METRICS)) +############################################################################ +# Data Loading +# ============ - print("Available REGRESSION autosklearn.metrics.*:") - print("\t*" + "\n\t*".join(autosklearn.metrics.REGRESSION_METRICS)) +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) - ############################################################################ - # First example: Use predefined accuracy metric - # ============================================= +############################################################################ +# Print a list of available metrics +# ================================= - print("#"*80) - print("Use predefined accuracy metric") - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=30, - seed=1, - metric=autosklearn.metrics.accuracy, - ) - cls.fit(X_train, y_train) - - predictions = cls.predict(X_test) - print("Accuracy score {:g} using {:s}". - format(sklearn.metrics.accuracy_score(y_test, predictions), - cls.automl_._metric.name)) - - ############################################################################ - # Second example: Use own accuracy metric - # ======================================= - - print("#"*80) - print("Use self defined accuracy metric") - accuracy_scorer = autosklearn.metrics.make_scorer( - name="accu", - score_func=custom_metrics.accuracy, - optimum=1, - greater_is_better=True, - needs_proba=False, - needs_threshold=False, - ) - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=30, - seed=1, - metric=accuracy_scorer, - ) - cls.fit(X_train, y_train) - - predictions = cls.predict(X_test) - print("Accuracy score {:g} using {:s}". 
- format(sklearn.metrics.accuracy_score(y_test, predictions), - cls.automl_._metric.name)) - - print("#"*80) - print("Use self defined error metric") - error_rate = autosklearn.metrics.make_scorer( - name='error', - score_func=custom_metrics.error, - optimum=0, - greater_is_better=False, - needs_proba=False, - needs_threshold=False - ) - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=30, - seed=1, - metric=error_rate, - ) - cls.fit(X_train, y_train) - - cls.predictions = cls.predict(X_test) - print("Error rate {:g} using {:s}". - format(error_rate(y_test, predictions), - cls.automl_._metric.name)) - - ############################################################################ - # Third example: Use own accuracy metric with additional argument - # =============================================================== - - print("#"*80) - print("Use self defined accuracy with additional argument") - accuracy_scorer = autosklearn.metrics.make_scorer( - name="accu_add", - score_func=custom_metrics.accuracy_wk, - optimum=1, - greater_is_better=True, - needs_proba=False, - needs_threshold=False, - dummy=None, - ) - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=30, - seed=1, - metric=accuracy_scorer - ) - cls.fit(X_train, y_train) - - predictions = cls.predict(X_test) - print( - "Accuracy score {:g} using {:s}".format( - sklearn.metrics.accuracy_score(y_test, predictions), - cls.automl_._metric.name - ) - ) +print("Available CLASSIFICATION metrics autosklearn.metrics.*:") +print("\t*" + "\n\t*".join(autosklearn.metrics.CLASSIFICATION_METRICS)) - print("#"*80) - print("Use self defined error with additional argument") - error_rate = autosklearn.metrics.make_scorer( - name="error_add", - score_func=custom_metrics.error_wk, - optimum=0, - greater_is_better=True, - needs_proba=False, - needs_threshold=False, - dummy=None, - ) - cls = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=30, - seed=1, - metric=error_rate, +print("Available REGRESSION autosklearn.metrics.*:") +print("\t*" + "\n\t*".join(autosklearn.metrics.REGRESSION_METRICS)) + +############################################################################ +# First example: Use predefined accuracy metric +# ============================================= + +print("#"*80) +print("Use predefined accuracy metric") +cls = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=60, + per_run_time_limit=30, + seed=1, + metric=autosklearn.metrics.accuracy, +) +cls.fit(X_train, y_train) + +predictions = cls.predict(X_test) +print("Accuracy score {:g} using {:s}". + format(sklearn.metrics.accuracy_score(y_test, predictions), + cls.automl_._metric.name)) + +############################################################################ +# Second example: Use own accuracy metric +# ======================================= + +print("#"*80) +print("Use self defined accuracy metric") +accuracy_scorer = autosklearn.metrics.make_scorer( + name="accu", + score_func=accuracy, + optimum=1, + greater_is_better=True, + needs_proba=False, + needs_threshold=False, +) +cls = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=60, + per_run_time_limit=30, + seed=1, + metric=accuracy_scorer, +) +cls.fit(X_train, y_train) + +predictions = cls.predict(X_test) +print("Accuracy score {:g} using {:s}". 
+      format(sklearn.metrics.accuracy_score(y_test, predictions),
+             cls.automl_._metric.name))
+
+print("#"*80)
+print("Use self defined error metric")
+error_rate = autosklearn.metrics.make_scorer(
+    name='error',
+    score_func=error,
+    optimum=0,
+    greater_is_better=False,
+    needs_proba=False,
+    needs_threshold=False
+)
+cls = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=60,
+    per_run_time_limit=30,
+    seed=1,
+    metric=error_rate,
+)
+cls.fit(X_train, y_train)
+
+predictions = cls.predict(X_test)
+print("Error rate {:g} using {:s}".
+      format(error_rate(y_test, predictions),
+             cls.automl_._metric.name))
+
+############################################################################
+# Third example: Use own accuracy metric with additional argument
+# ===============================================================
+
+print("#"*80)
+print("Use self defined accuracy with additional argument")
+accuracy_scorer = autosklearn.metrics.make_scorer(
+    name="accu_add",
+    score_func=accuracy_wk,
+    optimum=1,
+    greater_is_better=True,
+    needs_proba=False,
+    needs_threshold=False,
+    dummy=None,
+)
+cls = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=60,
+    per_run_time_limit=30,
+    seed=1,
+    metric=accuracy_scorer
+)
+cls.fit(X_train, y_train)
+
+predictions = cls.predict(X_test)
+print(
+    "Accuracy score {:g} using {:s}".format(
+        sklearn.metrics.accuracy_score(y_test, predictions),
+        cls.automl_._metric.name
+    )
+)
+
+print("#"*80)
+print("Use self defined error with additional argument")
+error_rate = autosklearn.metrics.make_scorer(
+    name="error_add",
+    score_func=error_wk,
+    optimum=0,
+    greater_is_better=False,
+    needs_proba=False,
+    needs_threshold=False,
+    dummy=None,
+)
+cls = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=60,
+    per_run_time_limit=30,
+    seed=1,
+    metric=error_rate,
+)
+cls.fit(X_train, y_train)
+
+predictions = cls.predict(X_test)
+print(
+    "Error rate {:g} using {:s}".format(
+        error_rate(y_test, predictions),
+        cls.automl_._metric.name
+    )
+)
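+
+############################################################################
+# Sanity-check a custom metric (illustrative sketch)
+# ==================================================
+#
+# This closing sketch is an editorial addition, not part of the original
+# example: before spending a full auto-sklearn run on a custom metric, the
+# underlying function can be exercised on a small toy vector (``numpy`` is
+# assumed to be imported as ``np`` at the top of this example, since the
+# custom functions above already rely on it).
+
+toy_solution = np.array([0, 1, 0, 1])
+toy_prediction = np.array([0, 1, 1, 1])
+# One of the four labels disagrees, so the expected error is 0.25
+print("Custom error on toy data:", error(toy_solution, toy_prediction))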
diff --git a/examples/40_advanced/example_pandas_train_test.py b/examples/40_advanced/example_pandas_train_test.py
index ba45523135..77498df0ab 100644
--- a/examples/40_advanced/example_pandas_train_test.py
+++ b/examples/40_advanced/example_pandas_train_test.py
@@ -59,97 +59,96 @@ def get_runhistory_models_performance(automl):
     return pd.DataFrame(performance_list)
 
-if __name__ == "__main__":
-    ############################################################################
-    # Data Loading
-    # ============
-
-    # Using Australian dataset https://www.openml.org/d/40981.
-    # This example will use the command fetch_openml, which will
-    # download a properly formatted dataframe if you use as_frame=True.
-    # For demonstration purposes, we will download a numpy array using
-    # as_frame=False, and manually creating the pandas DataFrame
-    X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=False)
-
-    # bool and category will be automatically encoded.
-    # Targets for classification are also automatically encoded
-    # If using fetch_openml, data is already properly encoded, below
-    # is an example for user reference
-    X = pd.DataFrame(
-        data=X,
-        columns=['A' + str(i) for i in range(1, 15)]
-    )
-    desired_boolean_columns = ['A1']
-    desired_categorical_columns = ['A4', 'A5', 'A6', 'A8', 'A9', 'A11', 'A12']
-    desired_numerical_columns = ['A2', 'A3', 'A7', 'A10', 'A13', 'A14']
-    for column in X.columns:
-        if column in desired_boolean_columns:
-            X[column] = X[column].astype('bool')
-        elif column in desired_categorical_columns:
-            X[column] = X[column].astype('category')
-        else:
-            X[column] = pd.to_numeric(X[column])
-
-    y = pd.DataFrame(y, dtype='category')
-
-    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(
-        X, y, test_size=0.5, random_state=3
-    )
-    print(X.dtypes)
-
-    ############################################################################
-    # Build and fit a classifier
-    # ==========================
-
-    cls = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=120,
-        per_run_time_limit=30,
-    )
-    cls.fit(X_train, y_train, X_test, y_test)
-
-    ###########################################################################
-    # Get the Score of the final ensemble
-    # ===================================
-
-    predictions = cls.predict(X_test)
-    print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
-
-    ############################################################################
-    # Plot the ensemble performance
-    # ===================================
-
-    ensemble_performance_frame = pd.DataFrame(cls.automl_.ensemble_performance_history)
-    best_values = pd.Series({'ensemble_optimization_score': -np.inf,
-                             'ensemble_test_score': -np.inf})
-    for idx in ensemble_performance_frame.index:
-        if (
-            ensemble_performance_frame.loc[idx, 'ensemble_optimization_score']
-            > best_values['ensemble_optimization_score']
-        ):
-            best_values = ensemble_performance_frame.loc[idx]
-        ensemble_performance_frame.loc[idx] = best_values
-
-    individual_performance_frame = get_runhistory_models_performance(cls)
-    best_values = pd.Series({'single_best_optimization_score': -np.inf,
-                             'single_best_test_score': -np.inf,
-                             'single_best_train_score': -np.inf})
-    for idx in individual_performance_frame.index:
-        if (
-            individual_performance_frame.loc[idx, 'single_best_optimization_score']
-            > best_values['single_best_optimization_score']
-        ):
-            best_values = individual_performance_frame.loc[idx]
-        individual_performance_frame.loc[idx] = best_values
-
-    pd.merge(
-        ensemble_performance_frame,
-        individual_performance_frame,
-        on="Timestamp", how='outer'
-    ).sort_values('Timestamp').fillna(method='ffill').plot(
-        x='Timestamp',
-        kind='line',
-        legend=True,
-        title='Auto-sklearn accuracy over time',
-        grid=True,
-    )
-    plt.show()
+############################################################################
+# Data Loading
+# ============
+
+# Using Australian dataset https://www.openml.org/d/40981.
+# This example will use the command fetch_openml, which will
+# download a properly formatted dataframe if you use as_frame=True.
+# For demonstration purposes, we will download a numpy array using
+# as_frame=False, and manually create the pandas DataFrame
+X, y = sklearn.datasets.fetch_openml(data_id=40981, return_X_y=True, as_frame=False)
+
+# bool and category will be automatically encoded.
+# Targets for classification are also automatically encoded +# If using fetch_openml, data is already properly encoded, below +# is an example for user reference +X = pd.DataFrame( + data=X, + columns=['A' + str(i) for i in range(1, 15)] +) +desired_boolean_columns = ['A1'] +desired_categorical_columns = ['A4', 'A5', 'A6', 'A8', 'A9', 'A11', 'A12'] +desired_numerical_columns = ['A2', 'A3', 'A7', 'A10', 'A13', 'A14'] +for column in X.columns: + if column in desired_boolean_columns: + X[column] = X[column].astype('bool') + elif column in desired_categorical_columns: + X[column] = X[column].astype('category') + else: + X[column] = pd.to_numeric(X[column]) + +y = pd.DataFrame(y, dtype='category') + +X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split( + X, y, test_size=0.5, random_state=3 +) +print(X.dtypes) + +############################################################################ +# Build and fit a classifier +# ========================== + +cls = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=120, + per_run_time_limit=30, +) +cls.fit(X_train, y_train, X_test, y_test) + +########################################################################### +# Get the Score of the final ensemble +# =================================== + +predictions = cls.predict(X_test) +print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) + +############################################################################ +# Plot the ensemble performance +# =================================== + +ensemble_performance_frame = pd.DataFrame(cls.automl_.ensemble_performance_history) +best_values = pd.Series({'ensemble_optimization_score': -np.inf, + 'ensemble_test_score': -np.inf}) +for idx in ensemble_performance_frame.index: + if ( + ensemble_performance_frame.loc[idx, 'ensemble_optimization_score'] + > best_values['ensemble_optimization_score'] + ): + best_values = ensemble_performance_frame.loc[idx] + ensemble_performance_frame.loc[idx] = best_values + +individual_performance_frame = get_runhistory_models_performance(cls) +best_values = pd.Series({'single_best_optimization_score': -np.inf, + 'single_best_test_score': -np.inf, + 'single_best_train_score': -np.inf}) +for idx in individual_performance_frame.index: + if ( + individual_performance_frame.loc[idx, 'single_best_optimization_score'] + > best_values['single_best_optimization_score'] + ): + best_values = individual_performance_frame.loc[idx] + individual_performance_frame.loc[idx] = best_values + +pd.merge( + ensemble_performance_frame, + individual_performance_frame, + on="Timestamp", how='outer' +).sort_values('Timestamp').fillna(method='ffill').plot( + x='Timestamp', + kind='line', + legend=True, + title='Auto-sklearn accuracy over time', + grid=True, +) +plt.show() diff --git a/examples/40_advanced/example_resampling.py b/examples/40_advanced/example_resampling.py index 3bcf040f7c..1654edd6cc 100644 --- a/examples/40_advanced/example_resampling.py +++ b/examples/40_advanced/example_resampling.py @@ -17,114 +17,113 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ============ - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) - - ############################################################################ - # Holdout - # ======= - - automl = 
autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=120, - per_run_time_limit=30, - tmp_folder='/tmp/autosklearn_resampling_example_tmp', - output_folder='/tmp/autosklearn_resampling_example_out', - disable_evaluator_output=False, - # 'holdout' with 'train_size'=0.67 is the default argument setting - # for AutoSklearnClassifier. It is explicitly specified in this example - # for demonstrational purpose. - resampling_strategy='holdout', - resampling_strategy_arguments={'train_size': 0.67}, - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - ############################################################################ - # Get the Score of the final ensemble - # =================================== - - predictions = automl.predict(X_test) - print("Accuracy score holdout: ", sklearn.metrics.accuracy_score(y_test, predictions)) - - - ############################################################################ - # Cross-validation - # ================ - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=120, - per_run_time_limit=30, - tmp_folder='/tmp/autosklearn_resampling_example_tmp', - output_folder='/tmp/autosklearn_resampling_example_out', - disable_evaluator_output=False, - resampling_strategy='cv', - resampling_strategy_arguments={'folds': 5}, - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - # One can use models trained during cross-validation directly to predict - # for unseen data. For this, all k models trained during k-fold - # cross-validation are considered as a single soft-voting ensemble inside - # the ensemble constructed with ensemble selection. - print('Before re-fit') - predictions = automl.predict(X_test) - print("Accuracy score CV", sklearn.metrics.accuracy_score(y_test, predictions)) - - ############################################################################ - # Perform a refit - # =============== - # During fit(), models are fit on individual cross-validation folds. To use - # all available data, we call refit() which trains all models in the - # final ensemble on the whole dataset. - print('After re-fit') - automl.refit(X_train.copy(), y_train.copy()) - predictions = automl.predict(X_test) - print("Accuracy score CV", sklearn.metrics.accuracy_score(y_test, predictions)) - - ############################################################################ - # scikit-learn splitter objects - # ============================= - # It is also possible to use - # `scikit-learn's splitter classes `_ to further customize the outputs. In case one needs to have 100% control over the - # splitting, it is possible to use - # `scikit-learn's PredefinedSplit `_. - - ############################################################################ - # Below is an example of using a predefined split. We split the training - # data by the first feature. In practice, one would use a splitting according - # to the use case at hand. 
-
-    resampling_strategy = sklearn.model_selection.PredefinedSplit
-    resampling_strategy_arguments = {'test_fold': np.where(X_train[:, 0] < np.mean(X_train[:, 0]))[0]}
-
-    automl = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=120,
-        per_run_time_limit=30,
-        tmp_folder='/tmp/autosklearn_resampling_example_tmp',
-        output_folder='/tmp/autosklearn_resampling_example_out',
-        disable_evaluator_output=False,
-        resampling_strategy=resampling_strategy,
-        resampling_strategy_arguments=resampling_strategy_arguments,
-    )
-    automl.fit(X_train, y_train, dataset_name='breast_cancer')
-
-    ############################################################################
-    # For custom resampling strategies (i.e. resampling strategies that are not
-    # defined as strings by Auto-sklearn) it is necessary to perform a refit:
-    automl.refit(X_train, y_train)
-
-    ############################################################################
-    # Get the Score of the final ensemble (again)
-    # ===========================================
-    #
-    # Obviously, this score is pretty bad as we "destroyed" the dataset by
-    # splitting it on the first feature.
-    predictions = automl.predict(X_test)
-    print("Accuracy score custom split", sklearn.metrics.accuracy_score(y_test, predictions))
+############################################################################
+# Data Loading
+# ============
+
+X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1)
+
+############################################################################
+# Holdout
+# =======
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=120,
+    per_run_time_limit=30,
+    tmp_folder='/tmp/autosklearn_resampling_example_tmp',
+    output_folder='/tmp/autosklearn_resampling_example_out',
+    disable_evaluator_output=False,
+    # 'holdout' with 'train_size'=0.67 is the default argument setting
+    # for AutoSklearnClassifier. It is explicitly specified in this example
+    # for demonstration purposes.
+    resampling_strategy='holdout',
+    resampling_strategy_arguments={'train_size': 0.67},
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+############################################################################
+# Get the Score of the final ensemble
+# ===================================
+
+predictions = automl.predict(X_test)
+print("Accuracy score holdout:", sklearn.metrics.accuracy_score(y_test, predictions))
+
+
+############################################################################
+# Cross-validation
+# ================
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=120,
+    per_run_time_limit=30,
+    tmp_folder='/tmp/autosklearn_resampling_example_tmp',
+    output_folder='/tmp/autosklearn_resampling_example_out',
+    disable_evaluator_output=False,
+    resampling_strategy='cv',
+    resampling_strategy_arguments={'folds': 5},
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+# One can use models trained during cross-validation directly to predict
+# for unseen data. For this, all k models trained during k-fold
+# cross-validation are considered as a single soft-voting ensemble inside
+# the ensemble constructed with ensemble selection.
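+
+# (Editorial aside, not part of the original example: those ensemble
+# members can be listed with ``get_models_with_weights()``, which returns
+# (weight, model) tuples, so the cross-validated models appear inside the
+# weighted ensemble built above.)
+for weight, model in automl.get_models_with_weights():
+    print("Ensemble member with weight {:.3f}: {}".format(
+        weight, type(model).__name__))
+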
+print('Before re-fit') +predictions = automl.predict(X_test) +print("Accuracy score CV", sklearn.metrics.accuracy_score(y_test, predictions)) + +############################################################################ +# Perform a refit +# =============== +# During fit(), models are fit on individual cross-validation folds. To use +# all available data, we call refit() which trains all models in the +# final ensemble on the whole dataset. +print('After re-fit') +automl.refit(X_train.copy(), y_train.copy()) +predictions = automl.predict(X_test) +print("Accuracy score CV", sklearn.metrics.accuracy_score(y_test, predictions)) + +############################################################################ +# scikit-learn splitter objects +# ============================= +# It is also possible to use +# `scikit-learn's splitter classes `_ to further customize the outputs. In case one needs to have 100% control over the +# splitting, it is possible to use +# `scikit-learn's PredefinedSplit `_. + +############################################################################ +# Below is an example of using a predefined split. We split the training +# data by the first feature. In practice, one would use a splitting according +# to the use case at hand. + +resampling_strategy = sklearn.model_selection.PredefinedSplit +resampling_strategy_arguments = {'test_fold': np.where(X_train[:, 0] < np.mean(X_train[:, 0]))[0]} + +automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=120, + per_run_time_limit=30, + tmp_folder='/tmp/autosklearn_resampling_example_tmp', + output_folder='/tmp/autosklearn_resampling_example_out', + disable_evaluator_output=False, + resampling_strategy=resampling_strategy, + resampling_strategy_arguments=resampling_strategy_arguments, +) +automl.fit(X_train, y_train, dataset_name='breast_cancer') + +############################################################################ +# For custom resampling strategies (i.e. resampling strategies that are not +# defined as strings by Auto-sklearn) it is necessary to perform a refit: +automl.refit(X_train, y_train) + +############################################################################ +# Get the Score of the final ensemble (again) +# =========================================== +# +# Obviously, this score is pretty bad as we "destroyed" the dataset by +# splitting it on the first feature. 
+predictions = automl.predict(X_test) +print("Accuracy score custom split", sklearn.metrics.accuracy_score(y_test, predictions)) diff --git a/examples/60_search/example_random_search.py b/examples/60_search/example_random_search.py index e1596ecfd9..0624b6cdd8 100644 --- a/examples/60_search/example_random_search.py +++ b/examples/60_search/example_random_search.py @@ -22,117 +22,116 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ============ - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) +############################################################################ +# Data Loading +# ============ + +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) + + +############################################################################ +# Fit a classifier using ROAR +# =========================== +def get_roar_object_callback( + scenario_dict, + seed, + ta, + ta_kwargs, + metalearning_configurations, + n_jobs, + dask_client, +): + """Random online adaptive racing.""" + + if n_jobs > 1 or (dask_client and len(dask_client.nthreads()) > 1): + raise ValueError("Please make sure to guard the code invoking Auto-sklearn by " + "`if __name__ == '__main__'` and remove this exception.") + + scenario = Scenario(scenario_dict) + return ROAR( + scenario=scenario, + rng=seed, + tae_runner=ta, + tae_runner_kwargs=ta_kwargs, + run_id=seed, + dask_client=dask_client, + n_jobs=n_jobs, + ) - ############################################################################ - # Fit a classifier using ROAR - # =========================== - def get_roar_object_callback( +automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=60, per_run_time_limit=15, + tmp_folder='/tmp/autosklearn_random_search_example_tmp', + output_folder='/tmp/autosklearn_random_search_example_out', + get_smac_object_callback=get_roar_object_callback, + initial_configurations_via_metalearning=0, +) +automl.fit(X_train, y_train, dataset_name='breast_cancer') + +print('#' * 80) +print('Results for ROAR.') +# Print the final ensemble constructed by auto-sklearn via ROAR. +print(automl.show_models()) +predictions = automl.predict(X_test) +# Print statistics about the auto-sklearn run such as number of +# iterations, number of models failed with a time out. 
+print(automl.sprint_statistics()) +print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) + + +############################################################################ +# Fit a classifier using Random Search +# ==================================== +def get_random_search_object_callback( scenario_dict, seed, ta, ta_kwargs, metalearning_configurations, n_jobs, - dask_client, - ): - """Random online adaptive racing.""" - - if n_jobs > 1 or (dask_client and len(dask_client.nthreads()) > 1): - raise ValueError("Please make sure to guard the code invoking Auto-sklearn by " - "`if __name__ == '__main__'` and remove this exception.") - - scenario = Scenario(scenario_dict) - return ROAR( - scenario=scenario, - rng=seed, - tae_runner=ta, - tae_runner_kwargs=ta_kwargs, - run_id=seed, - dask_client=dask_client, - n_jobs=n_jobs, - ) - - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, per_run_time_limit=15, - tmp_folder='/tmp/autosklearn_random_search_example_tmp', - output_folder='/tmp/autosklearn_random_search_example_out', - get_smac_object_callback=get_roar_object_callback, - initial_configurations_via_metalearning=0, - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - print('#' * 80) - print('Results for ROAR.') - # Print the final ensemble constructed by auto-sklearn via ROAR. - print(automl.show_models()) - predictions = automl.predict(X_test) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - - - ############################################################################ - # Fit a classifier using Random Search - # ==================================== - def get_random_search_object_callback( - scenario_dict, - seed, - ta, - ta_kwargs, - metalearning_configurations, - n_jobs, - dask_client - ): - """Random search.""" - - if n_jobs > 1 or (dask_client and len(dask_client.nthreads()) > 1): - raise ValueError("Please make sure to guard the code invoking Auto-sklearn by " - "`if __name__ == '__main__'` and remove this exception.") - - scenario_dict['minR'] = len(scenario_dict['instances']) - scenario_dict['initial_incumbent'] = 'RANDOM' - scenario = Scenario(scenario_dict) - return ROAR( - scenario=scenario, - rng=seed, - tae_runner=ta, - tae_runner_kwargs=ta_kwargs, - run_id=seed, - dask_client=dask_client, - n_jobs=n_jobs, - ) - - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=60, - per_run_time_limit=15, - tmp_folder='/tmp/autosklearn_random_search_example_tmp', - output_folder='/tmp/autosklearn_random_search_example_out', - get_smac_object_callback=get_random_search_object_callback, - initial_configurations_via_metalearning=0, + dask_client +): + """Random search.""" + + if n_jobs > 1 or (dask_client and len(dask_client.nthreads()) > 1): + raise ValueError("Please make sure to guard the code invoking Auto-sklearn by " + "`if __name__ == '__main__'` and remove this exception.") + + scenario_dict['minR'] = len(scenario_dict['instances']) + scenario_dict['initial_incumbent'] = 'RANDOM' + scenario = Scenario(scenario_dict) + return ROAR( + scenario=scenario, + rng=seed, + tae_runner=ta, + tae_runner_kwargs=ta_kwargs, + run_id=seed, + dask_client=dask_client, + n_jobs=n_jobs, ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - print('#' * 80) - print('Results for random 
search.') - # Print the final ensemble constructed by auto-sklearn via random search. - print(automl.show_models()) +automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=60, + per_run_time_limit=15, + tmp_folder='/tmp/autosklearn_random_search_example_tmp', + output_folder='/tmp/autosklearn_random_search_example_out', + get_smac_object_callback=get_random_search_object_callback, + initial_configurations_via_metalearning=0, +) +automl.fit(X_train, y_train, dataset_name='breast_cancer') + +print('#' * 80) +print('Results for random search.') + +# Print the final ensemble constructed by auto-sklearn via random search. +print(automl.show_models()) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) +# Print statistics about the auto-sklearn run such as number of +# iterations, number of models failed with a time out. +print(automl.sprint_statistics()) - predictions = automl.predict(X_test) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) +predictions = automl.predict(X_test) +print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) diff --git a/examples/60_search/example_sequential.py b/examples/60_search/example_sequential.py index 9dd6a6c992..6bbeb52fcc 100644 --- a/examples/60_search/example_sequential.py +++ b/examples/60_search/example_sequential.py @@ -16,46 +16,45 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Data Loading - # ====================================== - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) - - ############################################################################ - # Build and fit the classifier - # ====================================== - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=120, - per_run_time_limit=30, - tmp_folder='/tmp/autosklearn_sequential_example_tmp', - output_folder='/tmp/autosklearn_sequential_example_out', - # Do not construct ensembles in parallel to avoid using more than one - # core at a time. The ensemble will be constructed after auto-sklearn - # finished fitting all machine learning models. 
- ensemble_size=0, - delete_tmp_folder_after_terminate=False, - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - # This call to fit_ensemble uses all models trained in the previous call - # to fit to build an ensemble which can be used with automl.predict() - automl.fit_ensemble(y_train, ensemble_size=50) - - ############################################################################ - # Print the final ensemble constructed by auto-sklearn - # ==================================================== - - print(automl.show_models()) - - ############################################################################ - # Get the Score of the final ensemble - # =================================== - - predictions = automl.predict(X_test) - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) +############################################################################ +# Data Loading +# ====================================== + +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1) + +############################################################################ +# Build and fit the classifier +# ====================================== + +automl = autosklearn.classification.AutoSklearnClassifier( + time_left_for_this_task=120, + per_run_time_limit=30, + tmp_folder='/tmp/autosklearn_sequential_example_tmp', + output_folder='/tmp/autosklearn_sequential_example_out', + # Do not construct ensembles in parallel to avoid using more than one + # core at a time. The ensemble will be constructed after auto-sklearn + # finished fitting all machine learning models. + ensemble_size=0, + delete_tmp_folder_after_terminate=False, +) +automl.fit(X_train, y_train, dataset_name='breast_cancer') + +# This call to fit_ensemble uses all models trained in the previous call +# to fit to build an ensemble which can be used with automl.predict() +automl.fit_ensemble(y_train, ensemble_size=50) + +############################################################################ +# Print the final ensemble constructed by auto-sklearn +# ==================================================== + +print(automl.show_models()) + +############################################################################ +# Get the Score of the final ensemble +# =================================== + +predictions = automl.predict(X_test) +print(automl.sprint_statistics()) +print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) diff --git a/examples/60_search/example_successive_halving.py b/examples/60_search/example_successive_halving.py index 67111fcc10..7fd2924e52 100644 --- a/examples/60_search/example_successive_halving.py +++ b/examples/60_search/example_successive_halving.py @@ -18,220 +18,219 @@ import autosklearn.classification -if __name__ == "__main__": - ############################################################################ - # Define a callback that instantiates SuccessiveHalving - # ===================================================== - - def get_smac_object_callback(budget_type): - def get_smac_object( - scenario_dict, - seed, - ta, - ta_kwargs, - metalearning_configurations, - n_jobs, - dask_client, - ): - from smac.facade.smac_ac_facade import SMAC4AC - from smac.intensification.successive_halving import SuccessiveHalving - from smac.runhistory.runhistory2epm import RunHistory2EPM4LogCost - from smac.scenario.scenario import Scenario 
- - if n_jobs > 1 or (dask_client and len(dask_client.nthreads()) > 1): - raise ValueError("Please make sure to guard the code invoking Auto-sklearn by " - "`if __name__ == '__main__'` and remove this exception.") - - scenario = Scenario(scenario_dict) - if len(metalearning_configurations) > 0: - default_config = scenario.cs.get_default_configuration() - initial_configurations = [default_config] + metalearning_configurations - else: - initial_configurations = None - rh2EPM = RunHistory2EPM4LogCost - - ta_kwargs['budget_type'] = budget_type - - return SMAC4AC( - scenario=scenario, - rng=seed, - runhistory2epm=rh2EPM, - tae_runner=ta, - tae_runner_kwargs=ta_kwargs, - initial_configurations=initial_configurations, - run_id=seed, - intensifier=SuccessiveHalving, - intensifier_kwargs={ - 'initial_budget': 10.0, - 'max_budget': 100, - 'eta': 2, - 'min_chall': 1}, - n_jobs=n_jobs, - dask_client=dask_client, - ) - return get_smac_object - - - ############################################################################ - # Data Loading - # ============ - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True) - - ############################################################################ - # Build and fit a classifier - # ========================== - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=40, - per_run_time_limit=10, - tmp_folder='/tmp/autosklearn_sh_example_tmp', - output_folder='/tmp/autosklearn_sh_example_out', - disable_evaluator_output=False, - # 'holdout' with 'train_size'=0.67 is the default argument setting - # for AutoSklearnClassifier. It is explicitly specified in this example - # for demonstrational purpose. - resampling_strategy='holdout', - resampling_strategy_arguments={'train_size': 0.67}, - include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd', - 'passive_aggressive'], - include_preprocessors=['no_preprocessing'], - get_smac_object_callback=get_smac_object_callback('iterations'), - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - print(automl.show_models()) - predictions = automl.predict(X_test) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - - ############################################################################ - # We can also use cross-validation with successive halving - # ======================================================== - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True) - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=40, - per_run_time_limit=10, - tmp_folder='/tmp/autosklearn_sh_example_tmp_01', - output_folder='/tmp/autosklearn_sh_example_out_01', - disable_evaluator_output=False, - resampling_strategy='cv', - include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd', - 'passive_aggressive'], - include_preprocessors=['no_preprocessing'], - get_smac_object_callback=get_smac_object_callback('iterations'), - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - # Print the final ensemble constructed by auto-sklearn. 
- print(automl.show_models()) - automl.refit(X_train, y_train) - predictions = automl.predict(X_test) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - - ############################################################################ - # Use an iterative fit cross-validation with successive halving - # ============================================================= - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True) - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=40, - per_run_time_limit=10, - tmp_folder='/tmp/autosklearn_sh_example_tmp_cv_02', - output_folder='/tmp/autosklearn_sh_example_out_cv_02', - disable_evaluator_output=False, - resampling_strategy='cv-iterative-fit', - include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd', - 'passive_aggressive'], - include_preprocessors=['no_preprocessing'], - get_smac_object_callback=get_smac_object_callback('iterations'), - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - # Print the final ensemble constructed by auto-sklearn. - print(automl.show_models()) - automl.refit(X_train, y_train) - predictions = automl.predict(X_test) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - - ############################################################################ - # Next, we see the use of subsampling as a budget in Auto-sklearn - # =============================================================== - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True) - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=40, - per_run_time_limit=10, - tmp_folder='/tmp/autosklearn_sh_example_tmp_03', - output_folder='/tmp/autosklearn_sh_example_out_03', - disable_evaluator_output=False, - # 'holdout' with 'train_size'=0.67 is the default argument setting - # for AutoSklearnClassifier. It is explicitly specified in this example - # for demonstrational purpose. - resampling_strategy='holdout', - resampling_strategy_arguments={'train_size': 0.67}, - get_smac_object_callback=get_smac_object_callback('subsample'), - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - # Print the final ensemble constructed by auto-sklearn. - print(automl.show_models()) - predictions = automl.predict(X_test) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. 
- print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - - ############################################################################ - # Mixed budget approach - # ===================== - # Finally, there's a mixed budget type which uses iterations where possible and - # subsamples otherwise - - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True) - - automl = autosklearn.classification.AutoSklearnClassifier( - time_left_for_this_task=40, - per_run_time_limit=10, - tmp_folder='/tmp/autosklearn_sh_example_tmp_04', - output_folder='/tmp/autosklearn_sh_example_out_04', - disable_evaluator_output=False, - # 'holdout' with 'train_size'=0.67 is the default argument setting - # for AutoSklearnClassifier. It is explicitly specified in this example - # for demonstrational purpose. - resampling_strategy='holdout', - resampling_strategy_arguments={'train_size': 0.67}, - include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd'], - get_smac_object_callback=get_smac_object_callback('mixed'), - ) - automl.fit(X_train, y_train, dataset_name='breast_cancer') - - # Print the final ensemble constructed by auto-sklearn. - print(automl.show_models()) - predictions = automl.predict(X_test) - # Print statistics about the auto-sklearn run such as number of - # iterations, number of models failed with a time out. - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) +############################################################################ +# Define a callback that instantiates SuccessiveHalving +# ===================================================== + +def get_smac_object_callback(budget_type): + def get_smac_object( + scenario_dict, + seed, + ta, + ta_kwargs, + metalearning_configurations, + n_jobs, + dask_client, + ): + from smac.facade.smac_ac_facade import SMAC4AC + from smac.intensification.successive_halving import SuccessiveHalving + from smac.runhistory.runhistory2epm import RunHistory2EPM4LogCost + from smac.scenario.scenario import Scenario + + if n_jobs > 1 or (dask_client and len(dask_client.nthreads()) > 1): + raise ValueError("Please make sure to guard the code invoking Auto-sklearn by " + "`if __name__ == '__main__'` and remove this exception.") + + scenario = Scenario(scenario_dict) + if len(metalearning_configurations) > 0: + default_config = scenario.cs.get_default_configuration() + initial_configurations = [default_config] + metalearning_configurations + else: + initial_configurations = None + rh2EPM = RunHistory2EPM4LogCost + + ta_kwargs['budget_type'] = budget_type + + return SMAC4AC( + scenario=scenario, + rng=seed, + runhistory2epm=rh2EPM, + tae_runner=ta, + tae_runner_kwargs=ta_kwargs, + initial_configurations=initial_configurations, + run_id=seed, + intensifier=SuccessiveHalving, + intensifier_kwargs={ + 'initial_budget': 10.0, + 'max_budget': 100, + 'eta': 2, + 'min_chall': 1}, + n_jobs=n_jobs, + dask_client=dask_client, + ) + return get_smac_object + + +############################################################################ +# Data Loading +# ============ + +X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) +X_train, X_test, y_train, y_test = \ + sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True) + +############################################################################ +# 
Build and fit a classifier
+# ==========================
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=40,
+    per_run_time_limit=10,
+    tmp_folder='/tmp/autosklearn_sh_example_tmp',
+    output_folder='/tmp/autosklearn_sh_example_out',
+    disable_evaluator_output=False,
+    # 'holdout' with 'train_size'=0.67 is the default argument setting
+    # for AutoSklearnClassifier. It is explicitly specified in this example
+    # for demonstration purposes.
+    resampling_strategy='holdout',
+    resampling_strategy_arguments={'train_size': 0.67},
+    include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd',
+                        'passive_aggressive'],
+    include_preprocessors=['no_preprocessing'],
+    get_smac_object_callback=get_smac_object_callback('iterations'),
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+print(automl.show_models())
+predictions = automl.predict(X_test)
+# Print statistics about the auto-sklearn run such as number of
+# iterations, number of models failed with a time out.
+print(automl.sprint_statistics())
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+
+############################################################################
+# We can also use cross-validation with successive halving
+# ========================================================
+
+X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True)
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=40,
+    per_run_time_limit=10,
+    tmp_folder='/tmp/autosklearn_sh_example_tmp_01',
+    output_folder='/tmp/autosklearn_sh_example_out_01',
+    disable_evaluator_output=False,
+    resampling_strategy='cv',
+    include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd',
+                        'passive_aggressive'],
+    include_preprocessors=['no_preprocessing'],
+    get_smac_object_callback=get_smac_object_callback('iterations'),
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+# Print the final ensemble constructed by auto-sklearn.
+print(automl.show_models())
+automl.refit(X_train, y_train)
+predictions = automl.predict(X_test)
+# Print statistics about the auto-sklearn run such as number of
+# iterations, number of models failed with a time out.
+print(automl.sprint_statistics())
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+
+############################################################################
+# Use an iterative fit cross-validation with successive halving
+# =============================================================
+
+X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True)
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=40,
+    per_run_time_limit=10,
+    tmp_folder='/tmp/autosklearn_sh_example_tmp_cv_02',
+    output_folder='/tmp/autosklearn_sh_example_out_cv_02',
+    disable_evaluator_output=False,
+    resampling_strategy='cv-iterative-fit',
+    include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd',
+                        'passive_aggressive'],
+    include_preprocessors=['no_preprocessing'],
+    get_smac_object_callback=get_smac_object_callback('iterations'),
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+# Print the final ensemble constructed by auto-sklearn.
+print(automl.show_models())
+automl.refit(X_train, y_train)
+predictions = automl.predict(X_test)
+# Print statistics about the auto-sklearn run such as number of
+# iterations, number of models failed with a time out.
+print(automl.sprint_statistics())
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+
+############################################################################
+# Next, we see the use of subsampling as a budget in Auto-sklearn
+# ===============================================================
+
+X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True)
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=40,
+    per_run_time_limit=10,
+    tmp_folder='/tmp/autosklearn_sh_example_tmp_03',
+    output_folder='/tmp/autosklearn_sh_example_out_03',
+    disable_evaluator_output=False,
+    # 'holdout' with 'train_size'=0.67 is the default argument setting
+    # for AutoSklearnClassifier. It is explicitly specified in this example
+    # for demonstration purposes.
+    resampling_strategy='holdout',
+    resampling_strategy_arguments={'train_size': 0.67},
+    get_smac_object_callback=get_smac_object_callback('subsample'),
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+# Print the final ensemble constructed by auto-sklearn.
+print(automl.show_models())
+predictions = automl.predict(X_test)
+# Print statistics about the auto-sklearn run such as number of
+# iterations, number of models failed with a time out.
+print(automl.sprint_statistics())
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+
+############################################################################
+# Mixed budget approach
+# =====================
+# Finally, there is a mixed budget type, which uses iterations where possible
+# and subsampling otherwise.
+
+X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = \
+    sklearn.model_selection.train_test_split(X, y, random_state=1, shuffle=True)
+
+automl = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=40,
+    per_run_time_limit=10,
+    tmp_folder='/tmp/autosklearn_sh_example_tmp_04',
+    output_folder='/tmp/autosklearn_sh_example_out_04',
+    disable_evaluator_output=False,
+    # 'holdout' with 'train_size'=0.67 is the default argument setting
+    # for AutoSklearnClassifier. It is explicitly specified in this example
+    # for demonstration purposes.
+    resampling_strategy='holdout',
+    resampling_strategy_arguments={'train_size': 0.67},
+    include_estimators=['extra_trees', 'gradient_boosting', 'random_forest', 'sgd'],
+    get_smac_object_callback=get_smac_object_callback('mixed'),
+)
+automl.fit(X_train, y_train, dataset_name='breast_cancer')
+
+# Print the final ensemble constructed by auto-sklearn.
+print(automl.show_models())
+predictions = automl.predict(X_test)
+# Print statistics about the auto-sklearn run such as number of
+# iterations, number of models failed with a time out.
+print(automl.sprint_statistics())
+print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
diff --git a/examples/80_extending/example_extending_classification.py b/examples/80_extending/example_extending_classification.py
index 2a7ffc0da0..888e55dbbd 100644
--- a/examples/80_extending/example_extending_classification.py
+++ b/examples/80_extending/example_extending_classification.py
@@ -118,35 +118,32 @@ def get_hyperparameter_search_space(dataset_properties=None):
 cs = MLPClassifier.get_hyperparameter_search_space()
 print(cs)
 
+############################################################################
+# Data Loading
+# ============
+
+X, y = load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+############################################################################
+# Fit MLP classifier to the data
+# ==============================
+
+clf = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=30,
+    per_run_time_limit=10,
+    include_estimators=['MLPClassifier'],
+    # Below two flags are provided to speed up calculations
+    # Not recommended for a real implementation
+    initial_configurations_via_metalearning=0,
+    smac_scenario_args={'runcount_limit': 5},
+)
+clf.fit(X_train, y_train)
+
+############################################################################
+# Print test accuracy and statistics
+# ==================================
 
-if __name__ == "__main__":
-
-    ############################################################################
-    # Data Loading
-    # ============
-
-    X, y = load_breast_cancer(return_X_y=True)
-    X_train, X_test, y_train, y_test = train_test_split(X, y)
-
-    ############################################################################
-    # Fit MLP classifier to the data
-    # ==============================
-
-    clf = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=30,
-        per_run_time_limit=10,
-        include_estimators=['MLPClassifier'],
-        # Bellow two flags are provided to speed up calculations
-        # Not recommended for a real implementation
-        initial_configurations_via_metalearning=0,
-        smac_scenario_args={'runcount_limit': 5},
-    )
-    clf.fit(X_train, y_train)
-
-    ############################################################################
-    # Print test accuracy and statistics
-    # ==================================
-
-    y_pred = clf.predict(X_test)
-    print("accuracy: ", sklearn.metrics.accuracy_score(y_pred, y_test))
-    print(clf.show_models())
+y_pred = clf.predict(X_test)
+print("accuracy: ", sklearn.metrics.accuracy_score(y_test, y_pred))
+print(clf.show_models())
diff --git a/examples/80_extending/example_extending_preprocessor.py b/examples/80_extending/example_extending_preprocessor.py
index 8b9c865cc8..f83b6aec9f 100644
--- a/examples/80_extending/example_extending_preprocessor.py
+++ b/examples/80_extending/example_extending_preprocessor.py
@@ -92,41 +92,38 @@ def get_hyperparameter_search_space(dataset_properties=None):
 # Add LDA component to auto-sklearn.
 autosklearn.pipeline.components.feature_preprocessing.add_preprocessor(LDA)
+############################################################################
+# Create dataset
+# ==============
 
-if __name__ == "__main__":
-
-    ############################################################################
-    # Create dataset
-    # ==============
-
-    X, y = load_breast_cancer(return_X_y=True)
-    X_train, X_test, y_train, y_test = train_test_split(X, y)
-
-    ############################################################################
-    # Configuration space
-    # ===================
+X, y = load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y)
 
-    cs = LDA.get_hyperparameter_search_space()
-    print(cs)
+############################################################################
+# Configuration space
+# ===================
 
-    ############################################################################
-    # Fit the model using LDA as preprocessor
-    # =======================================
+cs = LDA.get_hyperparameter_search_space()
+print(cs)
 
-    clf = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=30,
-        include_preprocessors=['LDA'],
-        # Bellow two flags are provided to speed up calculations
-        # Not recommended for a real implementation
-        initial_configurations_via_metalearning=0,
-        smac_scenario_args={'runcount_limit': 5},
-    )
-    clf.fit(X_train, y_train)
+############################################################################
+# Fit the model using LDA as preprocessor
+# =======================================
+
+clf = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=30,
+    include_preprocessors=['LDA'],
+    # Below two flags are provided to speed up calculations
+    # Not recommended for a real implementation
+    initial_configurations_via_metalearning=0,
+    smac_scenario_args={'runcount_limit': 5},
+)
+clf.fit(X_train, y_train)
 
-    ############################################################################
-    # Print prediction score and statistics
-    # =====================================
+############################################################################
+# Print prediction score and statistics
+# =====================================
 
-    y_pred = clf.predict(X_test)
-    print("accuracy: ", sklearn.metrics.accuracy_score(y_pred, y_test))
-    print(clf.show_models())
+y_pred = clf.predict(X_test)
+print("accuracy: ", sklearn.metrics.accuracy_score(y_test, y_pred))
+print(clf.show_models())
diff --git a/examples/80_extending/example_extending_regression.py b/examples/80_extending/example_extending_regression.py
index c62c6857fb..e6ee7edd1f 100644
--- a/examples/80_extending/example_extending_regression.py
+++ b/examples/80_extending/example_extending_regression.py
@@ -103,34 +103,31 @@ def get_hyperparameter_search_space(dataset_properties=None):
 cs = KernelRidgeRegression.get_hyperparameter_search_space()
 print(cs)
 
+############################################################################
+# Generate data
+# =============
+
+X, y = load_diabetes(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y)
 
-if __name__ == "__main__":
-
-    ############################################################################
-    # Generate data
-    # =============
-
-    X, y = load_diabetes(return_X_y=True)
-    X_train, X_test, y_train, y_test = train_test_split(X, y)
-
-    ############################################################################
-    # Fit the model using KRR
-    # =======================
-
-    reg = autosklearn.regression.AutoSklearnRegressor(
-        time_left_for_this_task=30,
-        per_run_time_limit=10,
-        include_estimators=['KernelRidgeRegression'],
-        # Bellow two flags are provided to speed up calculations
-        # Not recommended for a real implementation
-        initial_configurations_via_metalearning=0,
-        smac_scenario_args={'runcount_limit': 5},
-    )
-    reg.fit(X_train, y_train)
-
-    ############################################################################
-    # Print prediction score and statistics
-    # =====================================
-    y_pred = reg.predict(X_test)
-    print("r2 score: ", sklearn.metrics.r2_score(y_pred, y_test))
-    print(reg.show_models())
+############################################################################
+# Fit the model using KRR
+# =======================
+
+reg = autosklearn.regression.AutoSklearnRegressor(
+    time_left_for_this_task=30,
+    per_run_time_limit=10,
+    include_estimators=['KernelRidgeRegression'],
+    # Below two flags are provided to speed up calculations
+    # Not recommended for a real implementation
+    initial_configurations_via_metalearning=0,
+    smac_scenario_args={'runcount_limit': 5},
+)
+reg.fit(X_train, y_train)
+
+############################################################################
+# Print prediction score and statistics
+# =====================================
+y_pred = reg.predict(X_test)
+# Note: r2_score is not symmetric, so the ground truth must come first.
+print("r2 score: ", sklearn.metrics.r2_score(y_test, y_pred))
+print(reg.show_models())
diff --git a/examples/80_extending/example_restrict_number_of_hyperparameters.py b/examples/80_extending/example_restrict_number_of_hyperparameters.py
index 0a21da4a71..f432275eee 100644
--- a/examples/80_extending/example_restrict_number_of_hyperparameters.py
+++ b/examples/80_extending/example_restrict_number_of_hyperparameters.py
@@ -21,115 +21,114 @@
 from autosklearn.pipeline.constants import DENSE, UNSIGNED_DATA, PREDICTIONS, SPARSE
 
 
-if __name__ == "__main__":
-    ############################################################################
-    # Subclass auto-sklearn's random forest classifier
-    # ================================================
-
-    # This classifier only has one of the hyperparameter's of auto-sklearn's
-    # default parametrization (``max_features``). Instead, it also
-    # tunes the number of estimators (``n_estimators``).
- - class CustomRandomForest(AutoSklearnClassificationAlgorithm): - def __init__(self, - n_estimators, - max_features, - random_state=None, - ): - self.n_estimators = n_estimators - self.max_features = max_features - self.random_state = random_state - - def fit(self, X, y): - from sklearn.ensemble import RandomForestClassifier - - self.n_estimators = int(self.n_estimators) - - if self.max_features not in ("sqrt", "log2", "auto"): - max_features = int(X.shape[1] ** float(self.max_features)) - else: - max_features = self.max_features - - self.estimator = RandomForestClassifier( - n_estimators=self.n_estimators, - max_features=max_features, - random_state=self.random_state, - ) - self.estimator.fit(X, y) - return self - - def predict(self, X): - if self.estimator is None: - raise NotImplementedError() - return self.estimator.predict(X) - - def predict_proba(self, X): - if self.estimator is None: - raise NotImplementedError() - return self.estimator.predict_proba(X) - - @staticmethod - def get_properties(dataset_properties=None): - return {'shortname': 'RF', - 'name': 'Random Forest Classifier', - 'handles_regression': False, - 'handles_classification': True, - 'handles_multiclass': True, - 'handles_multilabel': True, - 'handles_multioutput': False, - 'is_deterministic': True, - 'input': (DENSE, SPARSE, UNSIGNED_DATA), - 'output': (PREDICTIONS,)} - - @staticmethod - def get_hyperparameter_search_space(dataset_properties=None): - cs = ConfigurationSpace() - - # The maximum number of features used in the forest is calculated as m^max_features, where - # m is the total number of features, and max_features is the hyperparameter specified below. - # The default is 0.5, which yields sqrt(m) features as max_features in the estimator. This - # corresponds with Geurts' heuristic. - max_features = UniformFloatHyperparameter("max_features", 0., 1., default_value=0.5) - n_estimators = UniformIntegerHyperparameter("n_estimators", 10, 1000, default_value=100) - - cs.add_hyperparameters([max_features, n_estimators]) - return cs - - - # Add custom random forest classifier component to auto-sklearn. 
-    autosklearn.pipeline.components.classification.add_classifier(CustomRandomForest)
-    cs = CustomRandomForest.get_hyperparameter_search_space()
-    print(cs)
-
-    ############################################################################
-    # Data Loading
-    # ============
-
-    X, y = load_breast_cancer(return_X_y=True)
-    X_train, X_test, y_train, y_test = train_test_split(X, y)
-
-    ############################################################################
-    # Fit Random forest classifier to the data
-    # ========================================
-
-    clf = autosklearn.classification.AutoSklearnClassifier(
-        time_left_for_this_task=30,
-        per_run_time_limit=10,
-        # Here we exclude auto-sklearn's default random forest component
-        exclude_estimators=['random_forest'],
-        # Bellow two flags are provided to speed up calculations
-        # Not recommended for a real implementation
-        initial_configurations_via_metalearning=0,
-        smac_scenario_args={'runcount_limit': 1},
-    )
-    clf.fit(X_train, y_train)
-
-    ############################################################################
-    # Print the configuration space
-    # =============================
-
-    # Observe that this configuration space only contains our custom random
-    # forest, but not auto-sklearn's ``random_forest``
-    cs = clf.get_configuration_space(X_train, y_train)
-    assert 'random_forest' not in str(cs)
-    print(cs)
+############################################################################
+# Subclass auto-sklearn's random forest classifier
+# ================================================
+
+# This classifier keeps only one of the hyperparameters of auto-sklearn's
+# default parametrization (``max_features``). In addition, it also
+# tunes the number of estimators (``n_estimators``).
+
+class CustomRandomForest(AutoSklearnClassificationAlgorithm):
+    def __init__(self,
+                 n_estimators,
+                 max_features,
+                 random_state=None,
+                 ):
+        self.n_estimators = n_estimators
+        self.max_features = max_features
+        self.random_state = random_state
+        # Initialize the estimator so that predict() can raise the intended
+        # NotImplementedError instead of an AttributeError before fit().
+        self.estimator = None
+
+    def fit(self, X, y):
+        from sklearn.ensemble import RandomForestClassifier
+
+        self.n_estimators = int(self.n_estimators)
+
+        if self.max_features not in ("sqrt", "log2", "auto"):
+            max_features = int(X.shape[1] ** float(self.max_features))
+        else:
+            max_features = self.max_features
+
+        self.estimator = RandomForestClassifier(
+            n_estimators=self.n_estimators,
+            max_features=max_features,
+            random_state=self.random_state,
+        )
+        self.estimator.fit(X, y)
+        return self
+
+    def predict(self, X):
+        if self.estimator is None:
+            raise NotImplementedError()
+        return self.estimator.predict(X)
+
+    def predict_proba(self, X):
+        if self.estimator is None:
+            raise NotImplementedError()
+        return self.estimator.predict_proba(X)
+
+    @staticmethod
+    def get_properties(dataset_properties=None):
+        return {'shortname': 'RF',
+                'name': 'Random Forest Classifier',
+                'handles_regression': False,
+                'handles_classification': True,
+                'handles_multiclass': True,
+                'handles_multilabel': True,
+                'handles_multioutput': False,
+                'is_deterministic': True,
+                'input': (DENSE, SPARSE, UNSIGNED_DATA),
+                'output': (PREDICTIONS,)}
+
+    @staticmethod
+    def get_hyperparameter_search_space(dataset_properties=None):
+        cs = ConfigurationSpace()
+
+        # The maximum number of features used in the forest is calculated as m^max_features, where
+        # m is the total number of features, and max_features is the hyperparameter specified below.
+        # The default is 0.5, which yields sqrt(m) features as max_features in the estimator. This
+        # corresponds with Geurts' heuristic.
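+        # As a concrete illustration (numbers added by the editor, assuming the
+        # breast cancer data loaded below): with m = 30 features and the default
+        # max_features = 0.5, the ``fit`` method above computes
+        # int(30 ** 0.5) = 5 features to consider at each split.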
+        max_features = UniformFloatHyperparameter("max_features", 0., 1., default_value=0.5)
+        n_estimators = UniformIntegerHyperparameter("n_estimators", 10, 1000, default_value=100)
+
+        cs.add_hyperparameters([max_features, n_estimators])
+        return cs
+
+
+# Add custom random forest classifier component to auto-sklearn.
+autosklearn.pipeline.components.classification.add_classifier(CustomRandomForest)
+cs = CustomRandomForest.get_hyperparameter_search_space()
+print(cs)
+
+############################################################################
+# Data Loading
+# ============
+
+X, y = load_breast_cancer(return_X_y=True)
+X_train, X_test, y_train, y_test = train_test_split(X, y)
+
+############################################################################
+# Fit Random forest classifier to the data
+# ========================================
+
+clf = autosklearn.classification.AutoSklearnClassifier(
+    time_left_for_this_task=30,
+    per_run_time_limit=10,
+    # Here we exclude auto-sklearn's default random forest component
+    exclude_estimators=['random_forest'],
+    # Below two flags are provided to speed up calculations
+    # Not recommended for a real implementation
+    initial_configurations_via_metalearning=0,
+    smac_scenario_args={'runcount_limit': 1},
+)
+clf.fit(X_train, y_train)
+
+############################################################################
+# Print the configuration space
+# =============================
+
+# Observe that this configuration space only contains our custom random
+# forest, but not auto-sklearn's ``random_forest``
+cs = clf.get_configuration_space(X_train, y_train)
+assert 'random_forest' not in str(cs)
+print(cs)
diff --git a/test/test_util/test_single_thread_client.py b/test/test_util/test_single_thread_client.py
new file mode 100644
index 0000000000..34fe7736fe
--- /dev/null
+++ b/test/test_util/test_single_thread_client.py
@@ -0,0 +1,27 @@
+import dask.distributed
+
+from distributed.utils_test import inc
+
+import pytest
+
+from autosklearn.util.single_thread_client import SingleThreadedClient
+
+
+def test_single_thread_client_like_dask_client():
+    single_thread_client = SingleThreadedClient()
+    assert isinstance(single_thread_client, dask.distributed.Client)
+    future = single_thread_client.submit(inc, 1)
+    assert isinstance(future, dask.distributed.Future)
+    assert future.done()
+    assert future.result() == 2
+    assert sum(single_thread_client.nthreads().values()) == 1
+    single_thread_client.close()
+    single_thread_client.shutdown()
+
+    # Client/Futures are printed, so make sure str works.
+    # str calls __repr__, which is the purpose of the check below.
+    assert str(future) != ""
+    assert str(single_thread_client) != ""
+
+    with pytest.raises(NotImplementedError):
+        single_thread_client.get_scheduler_logs()
From d96f9ce53f5b1e061d6cd69db1fc55ddc7312e39 Mon Sep 17 00:00:00 2001
From: Matthias Feurer
Date: Wed, 20 Jan 2021 15:51:18 +0100
Subject: [PATCH 08/10] Update examples on parallel processing (#1059)

---
 .../example_parallel_manual_spawning.py       | 168 --------------
 .../example_parallel_manual_spawning_cli.py   | 215 ++++++++++++++++++
 ...example_parallel_manual_spawning_python.py | 154 +++++++++++++
 3 files changed, 369 insertions(+), 168 deletions(-)
 delete mode 100644 examples/60_search/example_parallel_manual_spawning.py
 create mode 100644 examples/60_search/example_parallel_manual_spawning_cli.py
 create mode 100644 examples/60_search/example_parallel_manual_spawning_python.py

diff --git a/examples/60_search/example_parallel_manual_spawning.py
b/examples/60_search/example_parallel_manual_spawning.py deleted file mode 100644 index 53be0c2bca..0000000000 --- a/examples/60_search/example_parallel_manual_spawning.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- encoding: utf-8 -*- -""" -=========================================== -Parallel Usage with manual process spawning -=========================================== - -*Auto-sklearn* uses -`dask.distributed _ -for parallel optimization. - -This example shows how to spawn workers for *Auto-sklearn* manually. -Use this example as a starting point to parallelize *Auto-sklearn* -across multiple machines. To run *Auto-sklearn* in parallel -on a single machine check out the example -`Parallel Usage on a single machine `_. -""" - -import asyncio -import multiprocessing -import subprocess -import time - -import dask -import dask.distributed -import sklearn.datasets -import sklearn.metrics - -from autosklearn.classification import AutoSklearnClassifier -from autosklearn.constants import MULTICLASS_CLASSIFICATION - -tmp_folder = '/tmp/autosklearn_parallel_2_example_tmp' -output_folder = '/tmp/autosklearn_parallel_2_example_out' - - -############################################################################ -# Dask configuration -# ================== -# -# Auto-sklearn uses threads in Dask to launch memory constrained jobs. -# This number of threads can be provided directly via the n_jobs argument -# when creating the AutoSklearnClassifier. Additionally, the user can provide -# a dask_client argument which can have processes=True. -# When using processes to True, we need to specify the below setting -# to allow internally generated processes. -# Optionally, you can choose to provide a dask client with processes=False -# and remove the following line. -dask.config.set({'distributed.worker.daemon': False}) - - -############################################################################ -# Start worker - Python -# ===================== -# -# This function demonstrates how to start a dask worker from python. This -# is a bit cumbersome and should ideally be done from the command line. -# We do it here for illustrational purpose, butalso start one worker from -# the command line below. - -# Check the dask docs at -# https://docs.dask.org/en/latest/setup/python-advanced.html for further -# information. - - -def start_python_worker(scheduler_address): - dask.config.set({'distributed.worker.daemon': False}) - - async def do_work(): - async with dask.distributed.Nanny( - scheduler_ip=scheduler_address, - nthreads=1, - lifetime=35, # automatically shut down the worker so this loop ends - ) as worker: - await worker.finished() - - asyncio.get_event_loop().run_until_complete(do_work()) - - -############################################################################ -# Start worker - CLI -# ================== -# -# It is also possible to start dask workers from the command line (in fact, -# one can also start a dask scheduler from the command line), see the -# `dask cli docs `_ for -# further information. -# Please not, that DASK_DISTRIBUTED__WORKER__DAEMON=False is required in this -# case as dask-worker creates a new process. That is, it is equivalent to the -# setting described above with dask.distributed.Client with processes=True -# -# Again, we need to make sure that we do not start the workers in a daemon -# mode. 
- -def start_cli_worker(scheduler_address): - call_string = ( - "DASK_DISTRIBUTED__WORKER__DAEMON=False " - "dask-worker %s --nthreads 1 --lifetime 35" - ) % scheduler_address - proc = subprocess.run(call_string, stdout=subprocess.PIPE, - stderr=subprocess.PIPE, shell=True, check=True) - while proc.returncode is None: - time.sleep(1) - - -############################################################################ -# Start Auto-sklearn -# ================== -# -# We are now ready to start *auto-sklearn. -# -# To use auto-sklearn in parallel we must guard the code with -# ``if __name__ == '__main__'``. We then start a dask cluster as a context, -# which means that it is automatically stopped one all computation is done. -if __name__ == '__main__': - X, y = sklearn.datasets.load_breast_cancer(return_X_y=True) - X_train, X_test, y_train, y_test = \ - sklearn.model_selection.train_test_split(X, y, random_state=1) - - # Create a dask compute cluster and a client manually - the former can - # be done via command line, too. - with dask.distributed.LocalCluster( - n_workers=0, processes=True, threads_per_worker=1, - ) as cluster, dask.distributed.Client(address=cluster.scheduler_address) as client: - - # now we start the two workers, one from within Python, the other - # via the command line. - process_python_worker = multiprocessing.Process( - target=start_python_worker, - args=(cluster.scheduler_address,), - ) - process_python_worker.start() - process_cli_worker = multiprocessing.Process( - target=start_cli_worker, - args=(cluster.scheduler_address,), - ) - process_cli_worker.start() - - # Wait a second for workers to become available - time.sleep(1) - - automl = AutoSklearnClassifier( - time_left_for_this_task=30, - per_run_time_limit=10, - memory_limit=1024, - tmp_folder=tmp_folder, - output_folder=output_folder, - seed=777, - # n_jobs is ignored internally as we pass a dask client. - n_jobs=1, - # Pass a dask client which connects to the previously constructed cluster. - dask_client=client, - ) - automl.fit(X_train, y_train) - - automl.fit_ensemble( - y_train, - task=MULTICLASS_CLASSIFICATION, - dataset_name='digits', - ensemble_size=20, - ensemble_nbest=50, - ) - - predictions = automl.predict(X_test) - print(automl.sprint_statistics()) - print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions)) - - # Wait until all workers are closed - process_python_worker.join() - process_cli_worker.join() diff --git a/examples/60_search/example_parallel_manual_spawning_cli.py b/examples/60_search/example_parallel_manual_spawning_cli.py new file mode 100644 index 0000000000..9eac944588 --- /dev/null +++ b/examples/60_search/example_parallel_manual_spawning_cli.py @@ -0,0 +1,215 @@ +# -*- encoding: utf-8 -*- +""" +====================================================== +Parallel Usage: Spawning workers from the command line +====================================================== + +*Auto-sklearn* uses +`dask.distributed `_ +for parallel optimization. + +This example shows how to start the dask scheduler and spawn +workers for *Auto-sklearn* manually from the command line. Use this example +as a starting point to parallelize *Auto-sklearn* across multiple +machines. If you want to start everything manually from within Python +please see `this example `_. +To run *Auto-sklearn* in parallel on a single machine check out the example +`Parallel Usage on a single machine `_. + +You can learn more about the dask command line interface from +https://docs.dask.org/en/latest/setup/cli.html. 
+
+When manually passing a dask client to Auto-sklearn, all logic
+must be guarded by ``if __name__ == "__main__":`` statements! We use
+multiple such statements to properly render this example as a notebook
+and also allow execution via the command line.
+
+Background
+==========
+
+To run Auto-sklearn distributed on multiple machines we need to set
+up three components:
+
+1. **Auto-sklearn and a dask client**. This will manage all workload, find new
+   configurations to evaluate and submit jobs via a dask client. As this
+   runs Bayesian optimization it should be executed on its own CPU.
+2. **The dask workers**. They will do the actual work of running machine
+   learning algorithms and require their own CPU each.
+3. **The scheduler**. It manages the communication between the dask client
+   and the different dask workers. As the client and all workers connect
+   to the scheduler it must be started first. This is a lightweight job
+   and does not require its own CPU.
+
+We will now start these three components in reverse order: scheduler,
+workers and client. Also, in a real setup, the scheduler and the workers should
+be started from the command line and not from within a Python file via
+the ``subprocess`` module as done here (for the sake of having a self-contained
+example).
+"""
+
+###########################################################################
+# Import statements
+# =================
+
+import multiprocessing
+import subprocess
+import time
+
+import dask.distributed
+import sklearn.datasets
+import sklearn.metrics
+
+from autosklearn.classification import AutoSklearnClassifier
+from autosklearn.constants import MULTICLASS_CLASSIFICATION
+
+tmp_folder = '/tmp/autosklearn_parallel_3_example_tmp'
+output_folder = '/tmp/autosklearn_parallel_3_example_out'
+
+worker_processes = []
+
+
+###########################################################################
+# 0. Setup client-scheduler communication
+# =======================================
+#
+# In this example the dask scheduler is started without an explicit
+# address and port. Instead, the scheduler takes a free port and stores
+# relevant information in a file for which we provided the name and
+# location. This filename is also given to the workers so they can find all
+# relevant information to connect to the scheduler.
+
+scheduler_file_name = 'scheduler-file.json'
+
+
+############################################################################
+# 1. Start scheduler
+# ==================
+#
+# Starting the scheduler is done with the following bash command:
+#
+# .. code:: bash
+#
+#     dask-scheduler --scheduler-file scheduler-file.json --idle-timeout 10
+#
+# We will now execute this bash command from within Python to have a
+# self-contained example:
+
+def cli_start_scheduler(scheduler_file_name):
+    call_string = (
+        "dask-scheduler --scheduler-file %s --idle-timeout 10"
+    ) % scheduler_file_name
+    proc = subprocess.run(call_string, stdout=subprocess.PIPE,
+                          stderr=subprocess.STDOUT, shell=True, check=True)
+    while proc.returncode is None:
+        time.sleep(1)
+
+
+if __name__ == "__main__":
+    process_python_worker = multiprocessing.Process(
+        target=cli_start_scheduler,
+        args=(scheduler_file_name, ),
+    )
+    process_python_worker.start()
+    worker_processes.append(process_python_worker)
+
+    # Wait a second for the scheduler to become available
+    time.sleep(1)
+
+
+############################################################################
+# 2. Start two workers
+# ====================
+#
+# Starting the workers is done with the following bash command:
+#
+# .. code:: bash
+#
+#     DASK_DISTRIBUTED__WORKER__DAEMON=False \
+#         dask-worker --nthreads 1 --lifetime 35 --memory-limit 0 \
+#         --scheduler-file scheduler-file.json
+#
+# We will now execute this bash command from within Python to have a
+# self-contained example. Please note that
+# ``DASK_DISTRIBUTED__WORKER__DAEMON=False`` is required in this
+# case as dask-worker creates a new process, which by default is not
+# compatible with Auto-sklearn creating new processes in the workers itself.
+# We disable dask's memory management by passing ``--memory-limit 0`` as
+# Auto-sklearn does the memory management itself.
+
+def cli_start_worker(scheduler_file_name):
+    call_string = (
+        "DASK_DISTRIBUTED__WORKER__DAEMON=False "
+        "dask-worker --nthreads 1 --lifetime 35 --memory-limit 0 "
+        "--scheduler-file %s"
+    ) % scheduler_file_name
+    proc = subprocess.run(call_string, stdout=subprocess.PIPE,
+                          stderr=subprocess.STDOUT, shell=True)
+    while proc.returncode is None:
+        time.sleep(1)
+
+
+if __name__ == '__main__':
+    for _ in range(2):
+        process_cli_worker = multiprocessing.Process(
+            target=cli_start_worker,
+            args=(scheduler_file_name, ),
+        )
+        process_cli_worker.start()
+        worker_processes.append(process_cli_worker)
+
+    # Wait a second for workers to become available
+    time.sleep(1)
+
+############################################################################
+# 3. Creating a client in Python
+# ==============================
+#
+# Finally, we create a dask client, which connects to the scheduler via
+# the information in the file created by the scheduler.
+
+client = dask.distributed.Client(scheduler_file=scheduler_file_name)
+
+############################################################################
+# Start Auto-sklearn
+# ~~~~~~~~~~~~~~~~~~
+if __name__ == "__main__":
+    X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+    X_train, X_test, y_train, y_test = \
+        sklearn.model_selection.train_test_split(X, y, random_state=1)
+    automl = AutoSklearnClassifier(
+        time_left_for_this_task=30,
+        per_run_time_limit=10,
+        memory_limit=1024,
+        tmp_folder=tmp_folder,
+        output_folder=output_folder,
+        seed=777,
+        # n_jobs is ignored internally as we pass a dask client.
+        n_jobs=1,
+        # Pass a dask client which connects to the previously constructed cluster.
+        dask_client=client,
+    )
+    automl.fit(X_train, y_train)
+
+    automl.fit_ensemble(
+        y_train,
+        task=MULTICLASS_CLASSIFICATION,
+        dataset_name='digits',
+        ensemble_size=20,
+        ensemble_nbest=50,
+    )
+
+    predictions = automl.predict(X_test)
+    print(automl.sprint_statistics())
+    print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+
+
+############################################################################
+# Wait until all workers are closed
+# =================================
+#
+# This is only necessary if the workers are started from within this Python
+# script. In a real application one would start them directly from the command
+# line.
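+#
+# A note on the join logic below (our reading of the code above, not a
+# statement from the dask documentation): ``worker_processes`` also contains
+# the process running the scheduler, so the loop waits for the scheduler as
+# well; the scheduler terminates on its own thanks to ``--idle-timeout 10``.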
+if __name__ == '__main__':
+    for process in worker_processes:
+        process.join()
diff --git a/examples/60_search/example_parallel_manual_spawning_python.py b/examples/60_search/example_parallel_manual_spawning_python.py
new file mode 100644
index 0000000000..491fd0bc0c
--- /dev/null
+++ b/examples/60_search/example_parallel_manual_spawning_python.py
@@ -0,0 +1,154 @@
+# -*- encoding: utf-8 -*-
+"""
+===================================================
+Parallel Usage: Spawning workers from within Python
+===================================================
+
+*Auto-sklearn* uses
+`dask.distributed `_
+for parallel optimization.
+
+This example shows how to start the dask scheduler and spawn
+workers for *Auto-sklearn* manually within Python. Use this example
+as a starting point to parallelize *Auto-sklearn* across multiple
+machines. If you want to start everything manually from the command line
+please see `this example `_.
+To run *Auto-sklearn* in parallel on a single machine check out the example
+`Parallel Usage on a single machine `_.
+
+When manually passing a dask client to Auto-sklearn, all logic
+must be guarded by ``if __name__ == "__main__":`` statements! We use
+multiple such statements to properly render this example as a notebook
+and also allow execution via the command line.
+
+Background
+==========
+
+To run Auto-sklearn distributed on multiple machines we need to set
+up three components:
+
+1. **Auto-sklearn and a dask client**. This will manage all workload, find new
+   configurations to evaluate and submit jobs via a dask client. As this
+   runs Bayesian optimization it should be executed on its own CPU.
+2. **The dask workers**. They will do the actual work of running machine
+   learning algorithms and require their own CPU each.
+3. **The scheduler**. It manages the communication between the dask client
+   and the different dask workers. As the client and all workers connect
+   to the scheduler it must be started first. This is a lightweight job
+   and does not require its own CPU.
+
+We will now start these three components in reverse order: scheduler,
+workers and client. Also, in a real setup, the scheduler and the workers should
+be started from the command line and not from within a Python file via
+the ``subprocess`` module as done here (for the sake of having a self-contained
+example).
+"""
+
+import asyncio
+import multiprocessing
+import time
+
+import dask
+import dask.distributed
+import sklearn.datasets
+import sklearn.metrics
+
+from autosklearn.classification import AutoSklearnClassifier
+from autosklearn.constants import MULTICLASS_CLASSIFICATION
+
+tmp_folder = '/tmp/autosklearn_parallel_2_example_tmp'
+output_folder = '/tmp/autosklearn_parallel_2_example_out'
+
+
+############################################################################
+# Define function to start worker
+# ===============================
+#
+# Define the function to start a dask worker from Python. This
+# is a bit cumbersome and should ideally be done from the command line.
+# We do it here only for illustration purposes.
+
+# Check the dask docs at
+# https://docs.dask.org/en/latest/setup/python-advanced.html for further
+# information.
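+#
+# As a short summary of what the function below relies on (our paraphrase of
+# the dask documentation): ``dask.distributed.Nanny`` launches and supervises
+# the actual worker process, and ``lifetime=35`` makes that worker shut
+# itself down after roughly 35 seconds, so that ``await worker.finished()``
+# returns and the event loop can exit.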
+
+def start_python_worker(scheduler_address):
+    dask.config.set({'distributed.worker.daemon': False})
+
+    async def do_work():
+        async with dask.distributed.Nanny(
+            scheduler_ip=scheduler_address,
+            nthreads=1,
+            lifetime=35,  # automatically shut down the worker so this loop ends
+            memory_limit=0,  # Disable memory management as it is done by Auto-sklearn itself
+        ) as worker:
+            await worker.finished()
+
+    asyncio.get_event_loop().run_until_complete(do_work())
+
+
+############################################################################
+# Start Auto-sklearn
+# ==================
+#
+# We are now ready to start *Auto-sklearn* and all dask-related processes.
+#
+# To use auto-sklearn in parallel we must guard the code with
+# ``if __name__ == '__main__'``. We then start a dask cluster as a context,
+# which means that it is automatically stopped once all computation is done.
+if __name__ == '__main__':
+    X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
+    X_train, X_test, y_train, y_test = \
+        sklearn.model_selection.train_test_split(X, y, random_state=1)
+
+    # 1. Create a dask scheduler (LocalCluster)
+    with dask.distributed.LocalCluster(
+        n_workers=0, processes=True, threads_per_worker=1,
+    ) as cluster:
+
+        # 2. Start the workers
+        # In this example both workers are started from within Python.
+        worker_processes = []
+        for _ in range(2):
+            process_python_worker = multiprocessing.Process(
+                target=start_python_worker,
+                args=(cluster.scheduler_address, ),
+            )
+            process_python_worker.start()
+            worker_processes.append(process_python_worker)
+
+        # Wait a second for workers to become available
+        time.sleep(1)
+
+        # 3. Start the client
+        with dask.distributed.Client(address=cluster.scheduler_address) as client:
+            automl = AutoSklearnClassifier(
+                time_left_for_this_task=30,
+                per_run_time_limit=10,
+                memory_limit=1024,
+                tmp_folder=tmp_folder,
+                output_folder=output_folder,
+                seed=777,
+                # n_jobs is ignored internally as we pass a dask client.
+                n_jobs=1,
+                # Pass a dask client which connects to the previously constructed cluster.
+                dask_client=client,
+            )
+            automl.fit(X_train, y_train)
+
+            automl.fit_ensemble(
+                y_train,
+                task=MULTICLASS_CLASSIFICATION,
+                dataset_name='digits',
+                ensemble_size=20,
+                ensemble_nbest=50,
+            )
+
+            predictions = automl.predict(X_test)
+            print(automl.sprint_statistics())
+            print("Accuracy score", sklearn.metrics.accuracy_score(y_test, predictions))
+
+        # Wait until all workers are closed
+        for process in worker_processes:
+            process.join()
From 26760aa25524cc1acc050cbf097933104a048636 Mon Sep 17 00:00:00 2001
From: Matthias Feurer
Date: Thu, 21 Jan 2021 15:13:36 +0100
Subject: [PATCH 09/10] Threads and forkserver (#1062)

* use threads again
* try the forkserver
* use forkserver pre-load for faster process starting
* streamline code
* de-duplicate code
* add missing file
* Update parallel.py
* Update parallel.py
---
 autosklearn/automl.py                         | 10 +++++-----
 autosklearn/ensemble_builder.py               |  4 +++-
 autosklearn/evaluation/__init__.py            |  6 ++++--
 autosklearn/util/parallel.py                  | 20 +++++++++++++++++++
 ...un_auto-sklearn_for_metadata_generation.py |  3 ++-
 test/conftest.py                              |  5 +----
 test/test_ensemble_builder/test_ensemble.py   |  4 ++--
 test/test_evaluation/test_evaluation.py       |  5 +++++
 8 files changed, 42 insertions(+), 15 deletions(-)
 create mode 100644 autosklearn/util/parallel.py

diff --git a/autosklearn/automl.py b/autosklearn/automl.py
index 0dbaa78b56..b5e8a9becd 100644
--- a/autosklearn/automl.py
+++ b/autosklearn/automl.py
@@ -49,6 +49,7 @@
     get_named_client_logger,
 )
 from autosklearn.util import pipeline, RE_PATTERN
+from autosklearn.util.parallel import preload_modules
 from autosklearn.ensemble_builder import EnsembleBuilderManager
 from autosklearn.ensembles.singlebest_ensemble import SingleBest
 from autosklearn.smbo import AutoMLSMBO
@@ -228,7 +229,7 @@ def __init__(self,
Nevertheless, multi-process runs # have spawn as requirement to reduce the # possibility of a deadlock - self._multiprocessing_context = 'spawn' + self._multiprocessing_context = 'forkserver' if self._n_jobs == 1 and self._dask_client is None: self._multiprocessing_context = 'fork' self._dask_client = SingleThreadedClient() @@ -248,11 +249,10 @@ def __init__(self, def _create_dask_client(self): self._is_dask_client_internally_created = True - dask.config.set({'distributed.worker.daemon': False}) self._dask_client = dask.distributed.Client( dask.distributed.LocalCluster( n_workers=self._n_jobs, - processes=True if self._n_jobs != 1 else False, + processes=False, threads_per_worker=1, # We use the temporal directory to save the # dask workers, because deleting workers @@ -299,8 +299,8 @@ def _get_logger(self, name): # under the above logging configuration setting # We need to specify the logger_name so that received records # are treated under the logger_name ROOT logger setting - context = multiprocessing.get_context( - self._multiprocessing_context) + context = multiprocessing.get_context(self._multiprocessing_context) + preload_modules(context) self.stop_logging_server = context.Event() port = context.Value('l') # be safe by using a long port.value = -1 diff --git a/autosklearn/ensemble_builder.py b/autosklearn/ensemble_builder.py index 3b1b2d8241..56db171197 100644 --- a/autosklearn/ensemble_builder.py +++ b/autosklearn/ensemble_builder.py @@ -31,6 +31,7 @@ from autosklearn.ensembles.ensemble_selection import EnsembleSelection from autosklearn.ensembles.abstract_ensemble import AbstractEnsemble from autosklearn.util.logging_ import get_named_client_logger +from autosklearn.util.parallel import preload_modules Y_ENSEMBLE = 0 Y_VALID = 1 @@ -572,11 +573,11 @@ def __init__( def run( self, iteration: int, + pynisher_context: str, time_left: Optional[float] = None, end_at: Optional[float] = None, time_buffer=5, return_predictions: bool = False, - pynisher_context: str = 'spawn', ): if time_left is None and end_at is None: @@ -606,6 +607,7 @@ def run( if wall_time_in_s < 1: break context = multiprocessing.get_context(pynisher_context) + preload_modules(context) safe_ensemble_script = pynisher.enforce_limits( wall_time_in_s=wall_time_in_s, diff --git a/autosklearn/evaluation/__init__.py b/autosklearn/evaluation/__init__.py index e6e78d0268..ff28d2d096 100644 --- a/autosklearn/evaluation/__init__.py +++ b/autosklearn/evaluation/__init__.py @@ -24,6 +24,7 @@ import autosklearn.evaluation.test_evaluator import autosklearn.evaluation.util from autosklearn.util.logging_ import get_named_client_logger +from autosklearn.util.parallel import preload_modules def fit_predict_try_except_decorator(ta, queue, cost_for_crash, **kwargs): @@ -97,12 +98,12 @@ def _encode_exit_status(exit_status): class ExecuteTaFuncWithQueue(AbstractTAFunc): def __init__(self, backend, autosklearn_seed, resampling_strategy, metric, - cost_for_crash, abort_on_first_run_crash, port, + cost_for_crash, abort_on_first_run_crash, port, pynisher_context, initial_num_run=1, stats=None, run_obj='quality', par_factor=1, scoring_functions=None, output_y_hat_optimization=True, include=None, exclude=None, memory_limit=None, disable_file_output=False, init_params=None, - budget_type=None, ta=False, pynisher_context='spawn', **resampling_strategy_args): + budget_type=None, ta=False, **resampling_strategy_args): if resampling_strategy == 'holdout': eval_function = autosklearn.evaluation.train_evaluator.eval_holdout @@ -261,6 +262,7 @@ def run( 
) -> Tuple[StatusType, float, float, Dict[str, Union[int, float, str, Dict, List, Tuple]]]: context = multiprocessing.get_context(self.pynisher_context) + preload_modules(context) queue = context.Queue() if not (instance_specific is None or instance_specific == '0'): diff --git a/autosklearn/util/parallel.py b/autosklearn/util/parallel.py new file mode 100644 index 0000000000..2f0ea6b016 --- /dev/null +++ b/autosklearn/util/parallel.py @@ -0,0 +1,20 @@ +import multiprocessing +import sys + + +def preload_modules(context: multiprocessing.context.BaseContext) -> None: + all_loaded_modules = sys.modules.keys() + preload = [ + loaded_module for loaded_module in all_loaded_modules + if loaded_module.split('.')[0] in ( + 'smac', + 'autosklearn', + 'numpy', + 'scipy', + 'pandas', + 'pynisher', + 'sklearn', + 'ConfigSpace', + ) and 'logging' not in loaded_module + ] + context.set_forkserver_preload(preload) diff --git a/scripts/run_auto-sklearn_for_metadata_generation.py b/scripts/run_auto-sklearn_for_metadata_generation.py index 0fb9d144e3..dc68d70cde 100644 --- a/scripts/run_auto-sklearn_for_metadata_generation.py +++ b/scripts/run_auto-sklearn_for_metadata_generation.py @@ -151,7 +151,8 @@ include=include, metric=automl_arguments['metric'], cost_for_crash=get_cost_of_crash(automl_arguments['metric']), - abort_on_first_run_crash=False,) + abort_on_first_run_crash=False, + pynisher_context='fork') run_info, run_value = ta.run_wrapper( RunInfo( config=config, diff --git a/test/conftest.py b/test/conftest.py index 4047711b01..4db32f6af7 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -3,7 +3,6 @@ import time import unittest.mock -import dask from dask.distributed import Client, get_client import psutil import pytest @@ -125,8 +124,7 @@ def dask_client(request): Workers are in subprocesses to not create deadlocks with the pynisher and logging. """ - dask.config.set({'distributed.worker.daemon': False}) - client = Client(n_workers=2, threads_per_worker=1, processes=True) + client = Client(n_workers=2, threads_per_worker=1, processes=False) print("Started Dask client={}\n".format(client)) def get_finalizer(address): @@ -151,7 +149,6 @@ def dask_client_single_worker(request): it is used very rarely to avoid this issue as much as possible. 
""" - dask.config.set({'distributed.worker.daemon': False}) client = Client(n_workers=1, threads_per_worker=1, processes=False) print("Started Dask client={}\n".format(client)) diff --git a/test/test_ensemble_builder/test_ensemble.py b/test/test_ensemble_builder/test_ensemble.py index c6e4929b5e..e23de18a3c 100644 --- a/test/test_ensemble_builder/test_ensemble.py +++ b/test/test_ensemble_builder/test_ensemble.py @@ -504,7 +504,7 @@ def test_run_end_at(ensemble_backend): current_time = time.time() - ensbuilder.run(end_at=current_time + 10, iteration=1) + ensbuilder.run(end_at=current_time + 10, iteration=1, pynisher_context='forkserver') # 4 seconds left because: 10 seconds - 5 seconds overhead - very little overhead, # but then rounded to an integer assert pynisher_mock.call_args_list[0][1]["wall_time_in_s"], 4 @@ -579,7 +579,7 @@ def mtime_mock(filename): # And then it still runs, but basically won't do anything any more except for raising error # messages via the logger - ensbuilder.run(time_left=1000, iteration=0) + ensbuilder.run(time_left=1000, iteration=0, pynisher_context='fork') assert os.path.exists(read_scores_file) assert not os.path.exists(read_preds_file) assert logger_mock.warning.call_count == 4 diff --git a/test/test_evaluation/test_evaluation.py b/test/test_evaluation/test_evaluation.py index e4d0550ca1..b76db52f1e 100644 --- a/test/test_evaluation/test_evaluation.py +++ b/test/test_evaluation/test_evaluation.py @@ -112,6 +112,7 @@ def test_zero_or_negative_cutoff(self, pynisher_mock): metric=accuracy, cost_for_crash=get_cost_of_crash(accuracy), abort_on_first_run_crash=False, + pynisher_context='forkserver', ) self.scenario.wallclock_limit = 5 self.stats.submitted_ta_runs += 1 @@ -130,6 +131,7 @@ def test_cutoff_lower_than_remaining_time(self, pynisher_mock): metric=accuracy, cost_for_crash=get_cost_of_crash(accuracy), abort_on_first_run_crash=False, + pynisher_context='forkserver', ) self.stats.ta_runs = 1 ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None, instance_specific=None, @@ -224,6 +226,7 @@ def test_eval_with_limits_holdout_fail_timeout(self, pynisher_mock): metric=accuracy, cost_for_crash=get_cost_of_crash(accuracy), abort_on_first_run_crash=False, + pynisher_context='forkserver', ) info = ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None, instance_specific=None, seed=1, capped=False)) @@ -259,6 +262,7 @@ def side_effect(**kwargs): metric=accuracy, cost_for_crash=get_cost_of_crash(accuracy), abort_on_first_run_crash=False, + pynisher_context='forkserver', ) info = ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None, instance_specific=None, seed=1, capped=False)) @@ -282,6 +286,7 @@ def side_effect(**kwargs): metric=accuracy, cost_for_crash=get_cost_of_crash(accuracy), abort_on_first_run_crash=False, + pynisher_context='forkserver', ) info = ta.run_wrapper(RunInfo(config=config, cutoff=30, instance=None, instance_specific=None, seed=1, capped=False)) From 0801bcb799d9a239454c73849f5cf5e377989589 Mon Sep 17 00:00:00 2001 From: Matthias Feurer Date: Mon, 25 Jan 2021 15:00:49 +0100 Subject: [PATCH 10/10] prepare new release --- autosklearn/__version__.py | 2 +- doc/releases.rst | 22 ++++++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/autosklearn/__version__.py b/autosklearn/__version__.py index 299c1f6235..9b15ea92ca 100644 --- a/autosklearn/__version__.py +++ b/autosklearn/__version__.py @@ -1,4 +1,4 @@ """Version information.""" # The following line *must* be the last in the module, exactly as 
formatted:
-__version__ = "0.12.2dev"
+__version__ = "0.12.2"
diff --git a/doc/releases.rst b/doc/releases.rst
index 08b3d9fb04..bd069ac6b4 100644
--- a/doc/releases.rst
+++ b/doc/releases.rst
@@ -12,6 +12,28 @@
 Releases
 ========
 
+Version 0.12.2
+==============
+
+* ADD #1045: New example demonstrating how to log multiple metrics during a run of Auto-sklearn.
+* DOC #1052: Add links to mybinder.
+* DOC #1059: Improved the example on manually starting workers for Auto-sklearn.
+* FIX #1046: Add the final result of the ensemble builder to the ensemble builder trajectory.
+* MAINT: Two log outputs of level warning about metadata were reduced to the info loglevel
+  as they are not actionable for the user.
+* MAINT #1062: Use threads for local dask workers and forkserver to start subprocesses to reduce
+  overhead.
+* MAINT #1053: Remove the restriction to guard single-core Auto-sklearn by
+  ``__name__ == "__main__"`` again.
+
+Contributors v0.12.2
+********************
+
+* Matthias Feurer
+* ROHIT AGARWAL
+* Francisco Rivera
+* Katharina Eggensperger
+
 Version 0.12.1
 ==============