From 1b37cd64fe84969c6959249c4fca1f58394d703c Mon Sep 17 00:00:00 2001
From: grefrathc <s23cgref@uni-bonn.de>
Date: Fri, 21 Jun 2024 15:44:29 +0200
Subject: [PATCH] issue #750 regularization strength for logistic classifier

---
 .../classification/_logistic_classifier.py    | 11 ++++----
 .../test_logistic_classifier.py               | 26 +++++++++++++++++++
 2 files changed, 32 insertions(+), 5 deletions(-)
 create mode 100644 tests/safeds/ml/classical/classification/test_logistic_classifier.py

diff --git a/src/safeds/ml/classical/classification/_logistic_classifier.py b/src/safeds/ml/classical/classification/_logistic_classifier.py
index e312e6b25..0c8e2856b 100644
--- a/src/safeds/ml/classical/classification/_logistic_classifier.py
+++ b/src/safeds/ml/classical/classification/_logistic_classifier.py
@@ -17,9 +17,9 @@ class LogisticClassifier(Classifier):
     # Dunder methods
     # ------------------------------------------------------------------------------------------------------------------
 
-    def __init__(self) -> None:
+    def __init__(self, c: float=1.0) -> None:
         super().__init__()
-
+        self.c = c
     def __hash__(self) -> int:
         return _structural_hash(
             super().__hash__(),
@@ -30,12 +30,13 @@ def __hash__(self) -> int:
     # ------------------------------------------------------------------------------------------------------------------
 
     def _clone(self) -> LogisticClassifier:
-        return LogisticClassifier()
-
+        return LogisticClassifier(c=self.c)  
+        
     def _get_sklearn_model(self) -> ClassifierMixin:
         from sklearn.linear_model import LogisticRegression as SklearnLogisticRegression
 
         return SklearnLogisticRegression(
             random_state=_get_random_seed(),
             n_jobs=-1,
-        )
+            C=self.c, 
+        )
\ No newline at end of file
diff --git a/tests/safeds/ml/classical/classification/test_logistic_classifier.py b/tests/safeds/ml/classical/classification/test_logistic_classifier.py
new file mode 100644
index 000000000..768beae22
--- /dev/null
+++ b/tests/safeds/ml/classical/classification/test_logistic_classifier.py
@@ -0,0 +1,26 @@
+import pytest
+from safeds.data.labeled.containers import TabularDataset
+from safeds.data.tabular.containers import Table
+from safeds.ml.classical.classification import LogisticClassifier
+
+
+@pytest.fixture()
+def training_set() -> TabularDataset:
+    table = Table({"col1": [1, 2, 3, 4], "col2": [1, 2, 3, 4]})
+    return table.to_tabular_dataset(target_name="col1")
+
+class TestC: 
+    def test_should_be_passed_to_fitted_model(self, training_set: TabularDataset) -> None:
+        fitted_model = LogisticClassifier(c=2).fit(training_set)
+        assert fitted_model.c == 2
+
+    def test_should_be_passed_to_sklearn(self, training_set: TabularDataset) -> None:
+        fitted_model = LogisticClassifier(c=2).fit(training_set)
+        assert fitted_model._wrapped_model is not None
+        assert fitted_model._wrapped_model.C == 2
+
+    def test_clone(self, training_set: TabularDataset) -> None:
+        fitted_model = LogisticClassifier(c=2).fit(training_set)
+        cloned_classifier = fitted_model._clone()
+        assert isinstance(cloned_classifier, LogisticClassifier)
+        assert cloned_classifier.c == fitted_model.c
\ No newline at end of file