Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make multi-processing in baseline models more consistent #909

Merged
Prev Previous commit
Next Next commit
docstring format
sibre28 committed Jul 12, 2024
commit 320942d1a37724a1e65a9977add5f56dfb20a052
Original file line number Diff line number Diff line change
@@ -35,8 +35,11 @@ class BaselineClassifier:

Get a baseline by fitting data on multiple different models and comparing the best metrics.

Parameters ---------- extended_search: If set to true, an extended set of models will be used to fit the
classifier. This might result in significantly higher runtime.
Parameters
----------
extended_search:
If set to true, an extended set of models will be used to fit the classifier.
This might result in significantly higher runtime.
"""

def __init__(self, extended_search: bool = False):
11 changes: 7 additions & 4 deletions src/safeds/ml/classical/regression/_baseline_regressor.py
Original file line number Diff line number Diff line change
@@ -39,11 +39,14 @@ class BaselineRegressor:

Get a baseline by fitting data on multiple different models and comparing the best metrics.

Parameters ---------- extended_search: If set to true, an extended set of models will be used to fit the
classifier. This might result in significantly higher runtime.
Parameters
----------
extended_search:
If set to true, an extended set of models will be used to fit the classifier.
This might result in significantly higher runtime.
"""

def __init__(self, include_slower_models: bool = False):
def __init__(self, extended_search: bool = False):
self._is_fitted = False
self._list_of_model_types = [
AdaBoostRegressor(),
@@ -54,7 +57,7 @@ def __init__(self, include_slower_models: bool = False):
SupportVectorRegressor(),
]

if include_slower_models:
if extended_search:
self._list_of_model_types.extend(
[ElasticNetRegressor(), LassoRegressor(), GradientBoostingRegressor()],
) # pragma: no cover