Skip to content

Commit

Permalink
Fix unittest?, increase coverage (hopefully)
Browse files Browse the repository at this point in the history
  • Loading branch information
mfeurer committed May 24, 2022
1 parent b01f1cb commit d8a863c
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 12 deletions.
12 changes: 9 additions & 3 deletions autosklearn/automl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2010,11 +2010,17 @@ def sprint_statistics(self) -> str:
)
)[0]
if len(idx_success) > 0:
key = (
"mean_test_score"
if len(self._metrics) == 1
else f"mean_test_" f"{self._metrics[0].name}"
)

if not self._metrics[0]._optimum:
idx_best_run = np.argmin(cv_results["mean_test_score"][idx_success])
idx_best_run = np.argmin(cv_results[key][idx_success])
else:
idx_best_run = np.argmax(cv_results["mean_test_score"][idx_success])
best_score = cv_results["mean_test_score"][idx_success][idx_best_run]
idx_best_run = np.argmax(cv_results[key][idx_success])
best_score = cv_results[key][idx_success][idx_best_run]
sio.write(" Best validation score: %f\n" % best_score)
num_runs = len(cv_results["status"])
sio.write(" Number of target algorithm runs: %d\n" % num_runs)
Expand Down
8 changes: 1 addition & 7 deletions examples/40_advanced/example_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
scorer = autosklearn.metrics.accuracy
cls = autosklearn.classification.AutoSklearnClassifier(
time_left_for_this_task=60,
per_run_time_limit=30,
seed=1,
metric=scorer,
)
Expand All @@ -107,7 +106,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
)
cls = autosklearn.classification.AutoSklearnClassifier(
time_left_for_this_task=60,
per_run_time_limit=30,
seed=1,
metric=accuracy_scorer,
)
Expand All @@ -133,7 +131,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
)
cls = autosklearn.classification.AutoSklearnClassifier(
time_left_for_this_task=60,
per_run_time_limit=30,
seed=1,
metric=error_rate,
)
Expand Down Expand Up @@ -184,7 +181,6 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
)
cls = autosklearn.classification.AutoSklearnClassifier(
time_left_for_this_task=60,
per_run_time_limit=30,
seed=1,
metric=error_rate,
)
Expand Down Expand Up @@ -217,10 +213,8 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
)
cls = autosklearn.classification.AutoSklearnClassifier(
time_left_for_this_task=60,
per_run_time_limit=30,
seed=1,
metric=accuracy_scorer,
ensemble_size=0,
)
cls.fit(X_train, y_train)

Expand All @@ -232,4 +226,4 @@ def metric_which_needs_x(solution, prediction, X_data, consider_col, val_thresho
consider_col=1,
val_threshold=18.8,
)
print(f"Error score {score:.3f} using {error_rate.name:s}")
print(f"Error score {score:.3f} using {accuracy_scorer.name:s}")
8 changes: 7 additions & 1 deletion test/test_automl/test_post_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def test__load_pareto_front(automl: AutoML) -> None:
"""
# Check that the predict function works
X = np.array([[1.0, 1.0, 1.0, 1.0]])
print(automl.predict(X))

assert automl.predict_proba(X).shape == (1, 3)
assert automl.predict(X).shape == (1,)

Expand All @@ -98,3 +98,9 @@ def test__load_pareto_front(automl: AutoML) -> None:
assert y_pred.shape == (1, 3)
y_pred = ensemble.predict(X)
assert y_pred in ["setosa", "versicolor", "virginica"]

statistics = automl.sprint_statistics()
assert "Metrics" in statistics
assert ("Best validation score: 0.9" in statistics) or (
"Best validation score: 1.0" in statistics
), statistics
2 changes: 1 addition & 1 deletion test/test_ensemble_builder/test_ensemble_builder_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def test_run_builds_valid_ensemble(builder: EnsembleBuilder) -> None:

assert mock_fit.call_count == 1
# Check that the ids of runs in the ensemble were all candidates
candidates = mock_fit.call_args.kwargs["candidates"]
candidates = mock_fit.call_args[1]["candidates"]
candidate_ids = {run.id for run in candidates}
assert ensemble_ids <= candidate_ids

Expand Down

0 comments on commit d8a863c

Please sign in to comment.