Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added alpha parameter to RidgeRegression #231

Merged
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
9beea27
feat: Added alpha parameter to `RidgeRegression`
Marsmaennchen221 Apr 21, 2023
1717416
Merge branch 'main' of https://github.com/Safe-DS/Stdlib into 164-set…
Marsmaennchen221 Apr 21, 2023
182455c
refactor: Removed temporary test and print statements
Marsmaennchen221 Apr 21, 2023
ae9acde
refactor: Provided more information for the ruff linter to pass
Marsmaennchen221 Apr 21, 2023
4e10ebf
style: apply automated linter fixes
megalinter-bot Apr 21, 2023
4de039c
style: apply automated linter fixes
megalinter-bot Apr 21, 2023
8b9f396
Merge branch 'main' into 164-set-alpha-parameter-for-regularization-o…
lars-reimann Apr 25, 2023
a5f99ef
style: apply automated linter fixes
megalinter-bot Apr 25, 2023
a768484
test: Added match parameter and expected Exception to test
Marsmaennchen221 Apr 28, 2023
42db542
style: apply automated linter fixes
megalinter-bot Apr 28, 2023
e3b1940
Merge branch 'main' into 164-set-alpha-parameter-for-regularization-o…
Marsmaennchen221 Apr 28, 2023
65418d5
docs: added docstring for Exception raised in RidgeRegression
Marsmaennchen221 Apr 28, 2023
a4e904d
Merge branch '164-set-alpha-parameter-for-regularization-of-ridgeregr…
Marsmaennchen221 Apr 28, 2023
fd97926
refactor: changed Error and Warn message for the alpha parameter in `…
Marsmaennchen221 Apr 28, 2023
8375c7d
style: apply automated linter fixes
megalinter-bot Apr 28, 2023
6279d64
Apply suggestions from code review
Marsmaennchen221 Apr 28, 2023
22818e6
Merge branch 'main' into 164-set-alpha-parameter-for-regularization-o…
lars-reimann Apr 28, 2023
d103cb2
test: `alpha` should be passed to sklearn
lars-reimann Apr 28, 2023
31c36a3
refactor: `alpha` attribute to `_alpha`
lars-reimann Apr 28, 2023
8a01930
docs: minor fix in related class
lars-reimann Apr 28, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 29 additions & 4 deletions src/safeds/ml/classical/regression/_ridge_regression.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING

from sklearn.linear_model import Ridge as sk_Ridge
Expand All @@ -13,12 +14,36 @@


class RidgeRegression(Regressor):
"""Ridge regression."""
"""
Ridge regression.

def __init__(self) -> None:
Parameters
----------
alpha : float
Controls the regularization strength. Has to be greater than or equal to 0.
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved

Raises
------
ValueError
If alpha is negative.
"""
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved

def __init__(self, alpha: float = 1.0) -> None:
self._wrapped_regressor: sk_Ridge | None = None
self._feature_names: list[str] | None = None
self._target_name: str | None = None
self.alpha = alpha
if self.alpha < 0:
raise ValueError("alpha must be positive")
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
if self.alpha == 0.0:
warnings.warn(
(
"Setting alpha to zero makes this model equivalent to LinearRegression. You should use"
" LinearRegression instead for better numerical stability."
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
),
UserWarning,
stacklevel=2,
)

def fit(self, training_set: TaggedTable) -> RidgeRegression:
"""
Expand All @@ -41,10 +66,10 @@ def fit(self, training_set: TaggedTable) -> RidgeRegression:
LearningError
If the training data contains invalid values or if the training failed.
"""
wrapped_regressor = sk_Ridge()
wrapped_regressor = sk_Ridge(alpha=self.alpha)
fit(wrapped_regressor, training_set)

result = RidgeRegression()
result = RidgeRegression(alpha=self.alpha)
result._wrapped_regressor = wrapped_regressor
result._feature_names = training_set.features.column_names
result._target_name = training_set.target.name
Expand Down
25 changes: 25 additions & 0 deletions tests/safeds/ml/classical/regression/test_ridge_regression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import pytest
from safeds.data.tabular.containers import Table
from safeds.ml.classical.regression import RidgeRegression


def test_ridge_regression_invalid() -> None:
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(ValueError, match="alpha must be positive"):
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
RidgeRegression(alpha=-1.0)


def test_ridge_regression_warning() -> None:
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
with pytest.warns(
UserWarning,
match=(
"Setting alpha to zero makes this model equivalent to LinearRegression. You should use LinearRegression"
" instead for better numerical stability."
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
),
):
RidgeRegression(alpha=0.0)


def test_ridge_regression() -> None:
Marsmaennchen221 marked this conversation as resolved.
Show resolved Hide resolved
regression = RidgeRegression(alpha=1.0)
fitted_regression = regression.fit(Table.from_dict({"A": [1, 2, 4], "B": [1, 2, 3]}).tag_columns("B"))
assert regression.alpha == fitted_regression.alpha