Skip to content

Commit

Permalink
refactor: alpha attribute to _alpha
Browse files Browse the repository at this point in the history
  • Loading branch information
lars-reimann committed Apr 28, 2023
1 parent d103cb2 commit 31c36a3
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
10 changes: 5 additions & 5 deletions src/safeds/ml/classical/regression/_ridge_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def __init__(self, alpha: float = 1.0) -> None:
self._wrapped_regressor: sk_Ridge | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
self.alpha = alpha
if self.alpha < 0:
self._alpha = alpha
if self._alpha < 0:
raise ValueError("alpha must be non-negative")
if self.alpha == 0.0:
if self._alpha == 0.0:
warnings.warn(
(
"Setting alpha to zero makes this model equivalent to LinearRegression. You should use "
Expand Down Expand Up @@ -66,10 +66,10 @@ def fit(self, training_set: TaggedTable) -> RidgeRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_Ridge(alpha=self.alpha)
wrapped_regressor = sk_Ridge(alpha=self._alpha)
fit(wrapped_regressor, training_set)

result = RidgeRegression(alpha=self.alpha)
result = RidgeRegression(alpha=self._alpha)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
4 changes: 2 additions & 2 deletions tests/safeds/ml/classical/regression/test_ridge_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,11 @@ def test_should_warn_if_alpha_is_zero() -> None:
def test_should_pass_alpha_to_fitted_regressor() -> None:
regressor = RidgeRegression(alpha=1.0)
fitted_regressor = regressor.fit(Table.from_dict({"A": [1, 2, 4], "B": [1, 2, 3]}).tag_columns("B"))
assert regressor.alpha == fitted_regressor.alpha
assert regressor._alpha == fitted_regressor._alpha


def test_should_pass_alpha_to_sklearn() -> None:
regressor = RidgeRegression(alpha=1.0)
fitted_regressor = regressor.fit(Table.from_dict({"A": [1, 2, 4], "B": [1, 2, 3]}).tag_columns("B"))
assert fitted_regressor._wrapped_regressor is not None
assert fitted_regressor._wrapped_regressor.alpha == fitted_regressor.alpha
assert fitted_regressor._wrapped_regressor.alpha == fitted_regressor._alpha

0 comments on commit 31c36a3

Please sign in to comment.