Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added alpha parameter to RidgeRegression #231

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9beea27
feat: Added alpha parameter to `RidgeRegression`
Marsmaennchen221 Apr 21, 2023
1717416
Merge branch 'main' of https://github.com/Safe-DS/Stdlib into 164-set…
Marsmaennchen221 Apr 21, 2023
182455c
refactor: Removed temporary test and print statements
Marsmaennchen221 Apr 21, 2023
ae9acde
refactor: Provided more information for the ruff linter to pass
Marsmaennchen221 Apr 21, 2023
4e10ebf
style: apply automated linter fixes
megalinter-bot Apr 21, 2023
4de039c
style: apply automated linter fixes
megalinter-bot Apr 21, 2023
8b9f396
Merge branch 'main' into 164-set-alpha-parameter-for-regularization-o…
lars-reimann Apr 25, 2023
a5f99ef
style: apply automated linter fixes
megalinter-bot Apr 25, 2023
a768484
test: Added match parameter and expected Exception to test
Marsmaennchen221 Apr 28, 2023
42db542
style: apply automated linter fixes
megalinter-bot Apr 28, 2023
e3b1940
Merge branch 'main' into 164-set-alpha-parameter-for-regularization-o…
Marsmaennchen221 Apr 28, 2023
65418d5
docs: added docstring for Exception raised in RidgeRegression
Marsmaennchen221 Apr 28, 2023
a4e904d
Merge branch '164-set-alpha-parameter-for-regularization-of-ridgeregr…
Marsmaennchen221 Apr 28, 2023
fd97926
refactor: changed Error and Warn message for the alpha parameter in `…
Marsmaennchen221 Apr 28, 2023
8375c7d
style: apply automated linter fixes
megalinter-bot Apr 28, 2023
6279d64
Apply suggestions from code review
Marsmaennchen221 Apr 28, 2023
22818e6
Merge branch 'main' into 164-set-alpha-parameter-for-regularization-o…
lars-reimann Apr 28, 2023
d103cb2
test: `alpha` should be passed to sklearn
lars-reimann Apr 28, 2023
31c36a3
refactor: `alpha` attribute to `_alpha`
lars-reimann Apr 28, 2023
8a01930
docs: minor fix in related class
lars-reimann Apr 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@ class ElasticNetRegression(Regressor):
----------
alpha : float
Controls the regularization of the model. The higher the value, the more regularized it becomes.

lasso_ratio: float
Number between 0 and 1 that controls the ratio between Lasso- and Ridge regularization.
lasso_ratio=0 is essentially RidgeRegression
lasso_ratio=1 is essentially LassoRegression
Number between 0 and 1 that controls the ratio between Lasso and Ridge regularization. If 0, only Ridge
regularization is used. If 1, only Lasso regularization is used.

Raises
------
Expand Down
33 changes: 29 additions & 4 deletions src/safeds/ml/classical/regression/_ridge_regression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

from sklearn.linear_model import Ridge as sk_Ridge
Expand All @@ -13,12 +14,36 @@


class RidgeRegression(Regressor):
"""Ridge regression."""
"""
Ridge regression.

def __init__(self) -> None:
Parameters
----------
alpha : float
Controls the regularization of the model. The higher the value, the more regularized it becomes.

Raises
------
ValueError
If alpha is negative.
"""
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, alpha: float = 1.0) -> None:
    """Create a ridge regression model with the given regularization strength."""
    # Reject invalid regularization strengths up front.
    if alpha < 0:
        raise ValueError("alpha must be non-negative")
    if alpha == 0.0:
        # alpha == 0 degenerates to ordinary least squares, so steer users
        # toward LinearRegression, which is numerically better behaved.
        warnings.warn(
            (
                "Setting alpha to zero makes this model equivalent to LinearRegression. You should use "
                "LinearRegression instead for better numerical stability."
            ),
            UserWarning,
            stacklevel=2,
        )
    self._alpha = alpha
    # Populated by fit(); all three stay None until the model is trained.
    self._wrapped_regressor: sk_Ridge | None = None
    self._feature_names: list[str] | None = None
    self._target_name: str | None = None

def fit(self, training_set: TaggedTable) -> RidgeRegression:
"""
Expand All @@ -41,10 +66,10 @@ def fit(self, training_set: TaggedTable) -> RidgeRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_Ridge()
wrapped_regressor = sk_Ridge(alpha=self._alpha)
fit(wrapped_regressor, training_set)

result = RidgeRegression()
result = RidgeRegression(alpha=self._alpha)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
32 changes: 32 additions & 0 deletions tests/safeds/ml/classical/regression/test_ridge_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression import RidgeRegression


def test_should_raise_if_alpha_is_negative() -> None:
    """Constructing a model with a negative alpha must raise a ValueError."""
    expected_message = "alpha must be non-negative"
    with pytest.raises(ValueError, match=expected_message):
        RidgeRegression(alpha=-1.0)


def test_should_warn_if_alpha_is_zero() -> None:
    """An alpha of zero must emit a UserWarning recommending LinearRegression."""
    expected_message = (
        "Setting alpha to zero makes this model equivalent to LinearRegression. You "
        "should use LinearRegression instead for better numerical stability."
    )
    with pytest.warns(UserWarning, match=expected_message):
        RidgeRegression(alpha=0.0)


def test_should_pass_alpha_to_fitted_regressor() -> None:
    """fit() must carry the configured alpha over to the model it returns."""
    training_set = Table.from_dict({"A": [1, 2, 4], "B": [1, 2, 3]}).tag_columns("B")
    unfitted = RidgeRegression(alpha=1.0)
    fitted = unfitted.fit(training_set)
    assert fitted._alpha == unfitted._alpha


def test_should_pass_alpha_to_sklearn() -> None:
    """The wrapped sklearn Ridge estimator must receive the same alpha value."""
    training_set = Table.from_dict({"A": [1, 2, 4], "B": [1, 2, 3]}).tag_columns("B")
    fitted = RidgeRegression(alpha=1.0).fit(training_set)
    wrapped = fitted._wrapped_regressor
    assert wrapped is not None
    assert wrapped.alpha == fitted._alpha