From 1b37cd64fe84969c6959249c4fca1f58394d703c Mon Sep 17 00:00:00 2001 From: grefrathc <s23cgref@uni-bonn.de> Date: Fri, 21 Jun 2024 15:44:29 +0200 Subject: [PATCH] issue #750 regularization strength for logistic classifier --- .../classification/_logistic_classifier.py | 11 ++++---- .../test_logistic_classifier.py | 26 +++++++++++++++++++ 2 files changed, 32 insertions(+), 5 deletions(-) create mode 100644 tests/safeds/ml/classical/classification/test_logistic_classifier.py diff --git a/src/safeds/ml/classical/classification/_logistic_classifier.py b/src/safeds/ml/classical/classification/_logistic_classifier.py index e312e6b25..0c8e2856b 100644 --- a/src/safeds/ml/classical/classification/_logistic_classifier.py +++ b/src/safeds/ml/classical/classification/_logistic_classifier.py @@ -17,9 +17,9 @@ class LogisticClassifier(Classifier): # Dunder methods # ------------------------------------------------------------------------------------------------------------------ - def __init__(self) -> None: + def __init__(self, c: float=1.0) -> None: super().__init__() - + self.c = c def __hash__(self) -> int: return _structural_hash( super().__hash__(), @@ -30,12 +30,13 @@ def __hash__(self) -> int: # ------------------------------------------------------------------------------------------------------------------ def _clone(self) -> LogisticClassifier: - return LogisticClassifier() - + return LogisticClassifier(c=self.c) + def _get_sklearn_model(self) -> ClassifierMixin: from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression return SklearnLogisticRegression( random_state=_get_random_seed(), n_jobs=-1, - ) + C=self.c, + ) \ No newline at end of file diff --git a/tests/safeds/ml/classical/classification/test_logistic_classifier.py b/tests/safeds/ml/classical/classification/test_logistic_classifier.py new file mode 100644 index 000000000..768beae22 --- /dev/null +++ b/tests/safeds/ml/classical/classification/test_logistic_classifier.py @@ -0,0 +1,26 @@ +import pytest +from safeds.data.labeled.containers import TabularDataset +from safeds.data.tabular.containers import Table +from safeds.ml.classical.classification import LogisticClassifier + + +@pytest.fixture() +def training_set() -> TabularDataset: + table = Table({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]}) + return table.to_tabular_dataset(target_name="col1") + +class TestC: + def test_should_be_passed_to_fitted_model(self, training_set: TabularDataset) -> None: + fitted_model = LogisticClassifier(c=2).fit(training_set) + assert fitted_model.c == 2 + + def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None: + fitted_model = LogisticClassifier(c=2).fit(training_set) + assert fitted_model._wrapped_model is not None + assert fitted_model._wrapped_model.C == 2 + + def test_clone(self, training_set: TabularDataset) -> None: + fitted_model = LogisticClassifier(c=2).fit(training_set) + cloned_classifier = fitted_model._clone() + assert isinstance(cloned_classifier, LogisticClassifier) + assert cloned_classifier.c == fitted_model.c \ No newline at end of file