From 5dca9503f0707c947f2b35ddd01e0a91359e3d2d Mon Sep 17 00:00:00 2001 From: Scott Date: Sun, 7 Mar 2021 21:06:54 -0600 Subject: [PATCH 01/17] Add loss='cce' to KerasClassifier --- scikeras/wrappers.py | 11 ++++++++++- tests/test_loss_auto.py | 24 +++++++++--------------- 2 files changed, 19 insertions(+), 16 deletions(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index a186cd29..5e66c163 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1328,7 +1328,7 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: raise ValueError( 'Only single-output models are supported with `loss="auto"`' ) - if self.target_type_ == "binary": + elif self.target_type_ == "binary": compile_kwargs["loss"] = "binary_crossentropy" elif self.target_type_ == "multiclass": if self.model_.outputs[0].shape[1] == 1: @@ -1336,6 +1336,15 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: f"Multi-class targets require the model to have >1 output unit instead of {self.model_.outputs[0].shape} units" ) compile_kwargs["loss"] = "sparse_categorical_crossentropy" + elif hasattr(self, "n_classes_"): + n_out = self.model_.outputs[0].shape[1] + if n_out != self.n_classes_: + raise ValueError( + "loss='categorical_crossentropy' is expecting the model " + f"to have {self.n_classes_} output neurons, one for each " + "class. However, only {n_out} output neurons were found" + ) + compile_kwargs["loss"] = "categorical_crossentropy" else: raise NotImplementedError( f'`loss="auto"` is not supported for tasks of type {self.target_type_}.' 
diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index 18138593..dd6097f3 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -99,15 +99,15 @@ def test_classifier_unsupported_multi_output_tasks(use_case): @pytest.mark.parametrize( - "use_case,supported", + "use_case", [ - ("binary_classification", True), - ("binary_classification_w_one_class", True), - ("classification_w_1d_targets", True), - ("classification_w_onehot_targets", False), + "binary_classification", + "binary_classification_w_one_class", + "classification_w_1d_targets", + "classification_w_onehot_targets", ], ) -def test_classifier_default_loss_only_model_specified(use_case, supported): +def test_classifier_default_loss_only_model_specified(use_case): """Test that KerasClassifier will auto-determine a loss function when only the model is specified. """ @@ -123,20 +123,14 @@ def test_classifier_default_loss_only_model_specified(use_case, supported): exp_loss = "sparse_categorical_crossentropy" y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) elif use_case == "classification_w_onehot_targets": + exp_loss = "categorical_crossentropy" y = np.random.choice(N_CLASSES, size=len(X)).astype(int) y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) est = KerasClassifier(model=shallow_net, model__single_output=model__single_output) - if supported: - est.fit(X, y=y) - assert loss_name(est.model_.loss) == exp_loss - else: - with pytest.raises( - NotImplementedError, - match='`loss="auto"` is not supported for tasks of type', - ): - est.fit(X, y=y) + est.fit(X, y=y) + assert loss_name(est.model_.loss) == exp_loss assert est.loss == "auto" From 3be3c26505b646eaff54d789f85d95037e39ee2d Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Sun, 7 Mar 2021 21:17:06 -0600 Subject: [PATCH 02/17] Update scikeras/wrappers.py --- scikeras/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 
07e15c63..d3739fd7 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1346,7 +1346,7 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: raise ValueError( "loss='categorical_crossentropy' is expecting the model " f"to have {self.n_classes_} output neurons, one for each " - "class. However, only {n_out} output neurons were found" + f"class. However, only {n_out} output neurons were found" ) compile_kwargs["loss"] = "categorical_crossentropy" else: From e7b4368daefc974f5104a314b130226f251b6850 Mon Sep 17 00:00:00 2001 From: Scott Sievert Date: Sun, 7 Mar 2021 21:22:41 -0600 Subject: [PATCH 03/17] Update scikeras/wrappers.py Co-authored-by: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> --- scikeras/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index d3739fd7..a0c76b46 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1340,7 +1340,7 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: "Multi-class targets require the model to have >1 output units." 
) compile_kwargs["loss"] = "sparse_categorical_crossentropy" - elif hasattr(self, "n_classes_"): + elif self.target_type_ == "multilabel-indicator" and hasattr(self, "n_classes_"): n_out = self.model_.outputs[0].shape[1] if n_out != self.n_classes_: raise ValueError( From a43768e46a0b8862df2d99c1e3b3ad27bb4f9a5b Mon Sep 17 00:00:00 2001 From: Scott Date: Sun, 7 Mar 2021 21:28:30 -0600 Subject: [PATCH 04/17] Catch valuerror --- tests/test_loss_auto.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index dd6097f3..af25d873 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -14,13 +14,13 @@ X = np.random.uniform(size=(n_eg, FEATURES)).astype("float32") -def shallow_net(single_output=False, loss=None, compile=False): +def shallow_net(outputs=None, loss=None, compile=False): model = tf.keras.Sequential() model.add(tf.keras.layers.Input(shape=(FEATURES,))) - if single_output: - model.add(tf.keras.layers.Dense(1)) - else: + if outputs is None: model.add(tf.keras.layers.Dense(N_CLASSES)) + else: + model.add(tf.keras.layers.Dense(outputs)) if compile: model.compile(loss=loss) @@ -45,7 +45,7 @@ def test_user_compiled(loss): """Test to make sure that user compiled classification models work with all classification losses. """ - model__single_output = True if "binary" in loss else False + model__outputs = 1 if "binary" in loss else None if loss == "binary_crosentropy": y = np.random.randint(0, 2, size=(n_eg,)) elif loss == "categorical_crossentropy": @@ -59,7 +59,7 @@ def test_user_compiled(loss): shallow_net, model__compile=True, model__loss=loss, - model__single_output=model__single_output, + model__outputs=model__outputs, ) est.partial_fit(X, y) @@ -112,7 +112,7 @@ def test_classifier_default_loss_only_model_specified(use_case): when only the model is specified. 
""" - model__single_output = True if "binary" in use_case else False + model__outputs = 1 if "binary" in use_case else None if use_case == "binary_classification": exp_loss = "binary_crossentropy" y = np.random.choice(2, size=len(X)).astype(int) @@ -127,7 +127,7 @@ def test_classifier_default_loss_only_model_specified(use_case): y = np.random.choice(N_CLASSES, size=len(X)).astype(int) y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) - est = KerasClassifier(model=shallow_net, model__single_output=model__single_output) + est = KerasClassifier(model=shallow_net, model__outputs=model__outputs) est.fit(X, y=y) assert loss_name(est.model_.loss) == exp_loss @@ -142,7 +142,7 @@ def test_regressor_default_loss_only_model_specified(use_case): y = np.random.uniform(size=len(X)) if use_case == "multi_output": y = np.column_stack([y, y]) - est = KerasRegressor(model=shallow_net, model__single_output=True) + est = KerasRegressor(model=shallow_net, model__outputs=1 if "single" in use_case else 2) est.fit(X, y) assert est.loss == "auto" assert loss_name(est.model_.loss) == "mean_squared_error" @@ -191,3 +191,13 @@ def test_multi_output_support(user_compiled, est_cls): match='Only single-output models are supported with `loss="auto"`', ): est.fit(X, y) + +def test_catch_bad_model_with_auto_loss_categorical_crossentropy(): + exp_loss = "categorical_crossentropy" + y = np.random.choice(N_CLASSES, size=len(X)).astype(int) + y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) + + est = KerasClassifier(model=shallow_net, model__outputs=N_CLASSES - 1) + msg = "loss='categorical_crossentropy' is expecting the model to have {N_CLASSES} output neurons" + with pytest.raises(ValueError, match=msg): + est.initialize(X, y=y) From 12139de9f653674e4649cc88180a73b9800254a6 Mon Sep 17 00:00:00 2001 From: Scott Date: Sun, 7 Mar 2021 21:30:00 -0600 Subject: [PATCH 05/17] black --- tests/test_loss_auto.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff 
--git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index 9730bd60..5686b8d1 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -69,7 +69,7 @@ def test_user_compiled(loss): class NoEncoderClf(KerasClassifier): """A classifier overriding default target encoding. - This simulates a user implementing custom encoding logic in + This simulates a user implementing custom encoding logic in target_encoder to support multiclass-multioutput or multilabel-indicator, which by default would raise an error. """ @@ -142,7 +142,9 @@ def test_regressor_default_loss_only_model_specified(use_case): y = np.random.uniform(size=len(X)) if use_case == "multi_output": y = np.column_stack([y, y]) - est = KerasRegressor(model=shallow_net, model__outputs=1 if "single" in use_case else 2) + est = KerasRegressor( + model=shallow_net, model__outputs=1 if "single" in use_case else 2 + ) est.fit(X, y) assert est.loss == "auto" assert loss_name(est.model_.loss) == "mean_squared_error" @@ -192,6 +194,7 @@ def test_multi_output_support(user_compiled, est_cls): ): est.fit(X, y) + def test_catch_bad_model_with_auto_loss_categorical_crossentropy(): exp_loss = "categorical_crossentropy" y = np.random.choice(N_CLASSES, size=len(X)).astype(int) @@ -202,6 +205,7 @@ def test_catch_bad_model_with_auto_loss_categorical_crossentropy(): with pytest.raises(ValueError, match=msg): est.initialize(X, y=y) + def test_multiclass_single_output_unit(): """Test that multiclass targets requires > 1 output units. 
""" From fc9e7ee98a8c4eec95c6f68b3629e31ce3874131 Mon Sep 17 00:00:00 2001 From: Scott Date: Sun, 7 Mar 2021 21:31:21 -0600 Subject: [PATCH 06/17] black --- scikeras/wrappers.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index a0c76b46..47d45e10 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1340,7 +1340,9 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: "Multi-class targets require the model to have >1 output units." ) compile_kwargs["loss"] = "sparse_categorical_crossentropy" - elif self.target_type_ == "multilabel-indicator" and hasattr(self, "n_classes_"): + elif self.target_type_ == "multilabel-indicator" and hasattr( + self, "n_classes_" + ): n_out = self.model_.outputs[0].shape[1] if n_out != self.n_classes_: raise ValueError( From 24d721fcb4c68d81e2a627897a90a0c7800b1d4c Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Mar 2021 14:21:44 -0600 Subject: [PATCH 07/17] warn and do not infer for multilabel-indicator --- scikeras/wrappers.py | 35 ++++++++++++++++--------------- tests/test_loss_auto.py | 46 ++++++++++++++++++++--------------------- 2 files changed, 40 insertions(+), 41 deletions(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 47d45e10..4fd4c99b 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -22,7 +22,6 @@ from tensorflow.keras import optimizers as optimizers_module from tensorflow.keras.models import Model from tensorflow.keras.utils import register_keras_serializable -from tensorflow.python.types.core import Value from scikeras._utils import ( TFRandomState, @@ -1328,35 +1327,37 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: raise ValueError( 'Only single-output models are supported with `loss="auto"`' ) + loss = None + hint = "" if self.target_type_ == "binary": if self.model_.outputs[0].shape[1] != 1: raise ValueError( 
"Binary classification expects a model with exactly 1 output unit." ) - compile_kwargs["loss"] = "binary_crossentropy" + loss = "binary_crossentropy" elif self.target_type_ == "multiclass": if self.model_.outputs[0].shape[1] == 1: raise ValueError( "Multi-class targets require the model to have >1 output units." ) - compile_kwargs["loss"] = "sparse_categorical_crossentropy" - elif self.target_type_ == "multilabel-indicator" and hasattr( - self, "n_classes_" - ): - n_out = self.model_.outputs[0].shape[1] - if n_out != self.n_classes_: - raise ValueError( - "loss='categorical_crossentropy' is expecting the model " - f"to have {self.n_classes_} output neurons, one for each " - f"class. However, only {n_out} output neurons were found" - ) - compile_kwargs["loss"] = "categorical_crossentropy" - else: - raise NotImplementedError( + loss = "sparse_categorical_crossentropy" + elif self.target_type_ == "multilabel-indicator": + # one-hot encoded multiclass problem OR multilabel-indicator problem + hint = ( + "For this type of problem, the following may help:" + '\n - If there is only one class per example, loss="categorical_crossentropy" might be appropriate.' + '\n - If there are multiple classes per example, loss="binary_crossentropy" might be appropriate.' + ) + if loss is None: + msg = ( f'`loss="auto"` is not supported for tasks of type {self.target_type_}.' 
- " Instead, you must explicitly pass a loss function, for example:" + "\nInstead, you must explicitly pass a loss function, for example:" '\n clf = KerasClassifier(..., loss="categorical_crossentropy")' ) + if hint: + msg += f"\n{hint}" + raise NotImplementedError(msg) + compile_kwargs["loss"] = loss self.model_.compile(**compile_kwargs) @staticmethod diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index 5686b8d1..c7eb9c15 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -79,11 +79,19 @@ def target_encoder(self): return FunctionTransformer() -@pytest.mark.parametrize("use_case", ["multilabel-indicator", "multiclass-multioutput"]) -def test_classifier_unsupported_multi_output_tasks(use_case): +@pytest.mark.parametrize( + "use_case,wrapper_cls", + [ + ("multilabel-indicator", NoEncoderClf), + ("multiclass-multioutput", NoEncoderClf), + ("classification_w_onehot_targets", KerasClassifier), + ], +) +def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls): """Test for an appropriate error for tasks that are not supported by `loss="auto"`. 
""" + extra = "" if use_case == "multiclass-multioutput": y1 = np.random.randint(0, 1, size=len(X)) y2 = np.random.randint(0, 2, size=len(X)) @@ -91,10 +99,16 @@ def test_classifier_unsupported_multi_output_tasks(use_case): elif use_case == "multilabel-indicator": y1 = np.random.randint(0, 1, size=len(X)) y = np.column_stack([y1, y1]) - est = NoEncoderClf(shallow_net, model__compile=False) - with pytest.raises( - NotImplementedError, match='`loss="auto"` is not supported for tasks of type' - ): + extra = 'loss="binary_crossentropy" might be appropriate' + elif use_case == "classification_w_onehot_targets": + y = np.random.choice(N_CLASSES, size=len(X)).astype(int) + y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) + extra = 'loss="categorical_crossentropy" might be appropriate' + est = wrapper_cls(shallow_net, model__compile=False) + match = '`loss="auto"` is not supported for tasks of type' + if extra: + match += f"(.|\n)+{extra}" + with pytest.raises(NotImplementedError, match=match): est.initialize(X, y) @@ -104,7 +118,6 @@ def test_classifier_unsupported_multi_output_tasks(use_case): "binary_classification", "binary_classification_w_one_class", "classification_w_1d_targets", - "classification_w_onehot_targets", ], ) def test_classifier_default_loss_only_model_specified(use_case): @@ -122,10 +135,6 @@ def test_classifier_default_loss_only_model_specified(use_case): elif use_case == "classification_w_1d_targets": exp_loss = "sparse_categorical_crossentropy" y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) - elif use_case == "classification_w_onehot_targets": - exp_loss = "categorical_crossentropy" - y = np.random.choice(N_CLASSES, size=len(X)).astype(int) - y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) est = KerasClassifier(model=shallow_net, model__outputs=model__outputs) @@ -195,21 +204,10 @@ def test_multi_output_support(user_compiled, est_cls): est.fit(X, y) -def 
test_catch_bad_model_with_auto_loss_categorical_crossentropy(): - exp_loss = "categorical_crossentropy" - y = np.random.choice(N_CLASSES, size=len(X)).astype(int) - y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) - - est = KerasClassifier(model=shallow_net, model__outputs=N_CLASSES - 1) - msg = "loss='categorical_crossentropy' is expecting the model to have {N_CLASSES} output neurons" - with pytest.raises(ValueError, match=msg): - est.initialize(X, y=y) - - def test_multiclass_single_output_unit(): """Test that multiclass targets requires > 1 output units. """ - est = KerasClassifier(model=shallow_net, model__single_output=True) + est = KerasClassifier(model=shallow_net, model__outputs=1) y = np.random.choice(N_CLASSES, size=(len(X), 1)).astype(int) with pytest.raises( ValueError, @@ -221,7 +219,7 @@ def test_multiclass_single_output_unit(): def test_binary_multiple_output_units(): """Test that binary targets requires exactly 1 output unit. """ - est = KerasClassifier(model=shallow_net, model__single_output=False) + est = KerasClassifier(model=shallow_net, model__outputs=2) y = np.random.choice(2, size=len(X)).astype(int) with pytest.raises( ValueError, From fc81a0ecfc63d99c5cc179c3a214cb9e0d055061 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Mar 2021 20:23:29 -0800 Subject: [PATCH 08/17] Update scikeras/wrappers.py Co-authored-by: Scott Sievert --- scikeras/wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 4fd4c99b..017e26ac 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1355,7 +1355,7 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: '\n clf = KerasClassifier(..., loss="categorical_crossentropy")' ) if hint: - msg += f"\n{hint}" + msg += f"\n\n{hint}" raise NotImplementedError(msg) compile_kwargs["loss"] = loss self.model_.compile(**compile_kwargs) From 
036c26d43200ad96220f83669a95178a986d0d2a Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Mar 2021 20:24:51 -0800 Subject: [PATCH 09/17] Update tests/test_loss_auto.py --- tests/test_loss_auto.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index c7eb9c15..6996112d 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -99,6 +99,7 @@ def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls): elif use_case == "multilabel-indicator": y1 = np.random.randint(0, 1, size=len(X)) y = np.column_stack([y1, y1]) + y[0, :] = 1 extra = 'loss="binary_crossentropy" might be appropriate' elif use_case == "classification_w_onehot_targets": y = np.random.choice(N_CLASSES, size=len(X)).astype(int) From 552ad1c9da00a299503b6e59b3e3f4e9f055b16f Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Mon, 8 Mar 2021 22:37:42 -0600 Subject: [PATCH 10/17] pr feedbackg --- scikeras/wrappers.py | 2 +- tests/test_loss_auto.py | 12 ++++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/scikeras/wrappers.py b/scikeras/wrappers.py index 017e26ac..1adbeaea 100644 --- a/scikeras/wrappers.py +++ b/scikeras/wrappers.py @@ -1351,7 +1351,7 @@ def _compile_model(self, compile_kwargs: Dict[str, Any]) -> None: if loss is None: msg = ( f'`loss="auto"` is not supported for tasks of type {self.target_type_}.' 
- "\nInstead, you must explicitly pass a loss function, for example:" + "\nInstead, you must compile the model yourself or explicitly pass a loss function, for example:" '\n clf = KerasClassifier(..., loss="categorical_crossentropy")' ) if hint: diff --git a/tests/test_loss_auto.py b/tests/test_loss_auto.py index 6996112d..35f12efe 100644 --- a/tests/test_loss_auto.py +++ b/tests/test_loss_auto.py @@ -92,6 +92,7 @@ def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls): by `loss="auto"`. """ extra = "" + fix_loss = None if use_case == "multiclass-multioutput": y1 = np.random.randint(0, 1, size=len(X)) y2 = np.random.randint(0, 2, size=len(X)) @@ -100,17 +101,20 @@ def test_classifier_unsupported_multi_output_tasks(use_case, wrapper_cls): y1 = np.random.randint(0, 1, size=len(X)) y = np.column_stack([y1, y1]) y[0, :] = 1 - extra = 'loss="binary_crossentropy" might be appropriate' + fix_loss = "binary_crossentropy" + extra = f'loss="{fix_loss}" might be appropriate' elif use_case == "classification_w_onehot_targets": y = np.random.choice(N_CLASSES, size=len(X)).astype(int) y = OneHotEncoder(sparse=False).fit_transform(y.reshape(-1, 1)) - extra = 'loss="categorical_crossentropy" might be appropriate' - est = wrapper_cls(shallow_net, model__compile=False) + fix_loss = "categorical_crossentropy" + extra = f'loss="{fix_loss}" might be appropriate' match = '`loss="auto"` is not supported for tasks of type' if extra: match += f"(.|\n)+{extra}" with pytest.raises(NotImplementedError, match=match): - est.initialize(X, y) + wrapper_cls(shallow_net, model__compile=False).initialize(X, y) + if fix_loss: + wrapper_cls(shallow_net, model__compile=False, loss=fix_loss).initialize(X, y) @pytest.mark.parametrize( From 61f3ee17363361c35f0b59659776411dcea75baa Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 30 Mar 2021 22:01:49 -0500 Subject: [PATCH 11/17] documentation --- docs/source/advanced.rst | 34 
+++++++++++++++++++++++++++++++--- docs/source/quickstart.rst | 5 ++++- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index c771d394..15168146 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -1,6 +1,6 @@ -=================================== -Advanced Usage of SciKeras Wrappers -=================================== +============== +Advanced Usage +============== Wrapper Classes --------------- @@ -128,6 +128,34 @@ offer an easy way to compile and tune compilation parameters. Examples: In all cases, returning an un-compiled model is equivalent to calling ``model.compile(**compile_kwargs)`` within ``model_build_fn``. +Loss selection +++++++++++++++ + +If you do not explicitly define a loss, SciKeras attempts to find a loss +that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`). +Losses are selected as follows: + +Classification +.............. + ++-----------+-----------+----------+---------------------------------+ +| # outputs | # classes | encoding | loss | ++===========+===========+==========+=================================+ +| 1 | <= 2 | any | binary crossentropy | ++-----------+-----------+----------+---------------------------------+ +| 1 | >=2 | labels | sparse categorical crossentropy | ++-----------+-----------+----------+---------------------------------+ +| 1 | >=2 | one-hot | unsupported | ++-----------+-----------+----------+---------------------------------+ +| > 1 | -- | -- | unsupported | ++-----------+-----------+----------+---------------------------------+ + +Regression +.......... + +Regression always defaults to mean squared error. +For multi-output models, Keras will use the sum of each output's loss. 
+ Arguments to ``model_build_fn`` ------------------------------- diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 8555ea29..0cfbd42c 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -40,7 +40,6 @@ it on a toy classification dataset using SciKeras clf = KerasClassifier( get_model, - loss="sparse_categorical_crossentropy", hidden_layer_dim=100, ) @@ -48,6 +47,10 @@ it on a toy classification dataset using SciKeras y_proba = clf.predict_proba(X) +Note that SciKeras even chooses a loss function and compiles your model +(see the :ref:`Advanced Usage` section for more details)! + + In an sklearn Pipeline ---------------------- From 01fe3b7ba9abeb27e5f824f36b72ae0781945040 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 30 Mar 2021 22:02:56 -0500 Subject: [PATCH 12/17] newline --- docs/source/quickstart.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 0cfbd42c..4621440f 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -50,7 +50,6 @@ it on a toy classification dataset using SciKeras Note that SciKeras even chooses a loss function and compiles your model (see the :ref:`Advanced Usage` section for more details)! - In an sklearn Pipeline ---------------------- From a91cf218193905fc62d9fead563c13a2466b8d6b Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 30 Mar 2021 22:03:17 -0500 Subject: [PATCH 13/17] newline --- docs/source/advanced.rst | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 15168146..042be2be 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -156,7 +156,6 @@ Regression Regression always defaults to mean squared error. For multi-output models, Keras will use the sum of each output's loss. 
- Arguments to ``model_build_fn`` ------------------------------- From 9142bdfd6ccf76d6e0f5fc5afb482ff9dbc8c7ae Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Tue, 30 Mar 2021 22:09:52 -0500 Subject: [PATCH 14/17] add links to resources --- docs/source/advanced.rst | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 042be2be..934a3440 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -133,7 +133,19 @@ Loss selection If you do not explicitly define a loss, SciKeras attempts to find a loss that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`). -Losses are selected as follows: + +To override the default loss, simply specify a loss function: + +.. code-block:: diff + + -KerasClassifier(model=model_build_fn) + +KerasClassifier(model=model_build_fn, loss="categorical_crossentropy") + +For guidance selecting losses in Keras, please see Jason Brownlee's +excelent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_ +as well as `Keras Losses docs`_. + +Default losses are selected as follows: Classification .............. @@ -314,3 +326,7 @@ and :class:`scikeras.wrappers.KerasRegressor` respectively. To override these sc .. _Keras Callbacks docs: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks .. _Keras Metrics docs: https://www.tensorflow.org/api_docs/python/tf/keras/metrics + +.. _Keras Losses docs: https://www.tensorflow.org/api_docs/python/tf/keras/losses + +.. 
_How to Choose Loss Functions When Training Deep Learning Neural Networks: https://machinelearningmastery.com/how-to-choose-loss-functions-when-training-deep-learning-neural-networks/ \ No newline at end of file From a6c250910234d53a2275e65116ce1e69c1fe2530 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 31 Mar 2021 08:12:07 -0500 Subject: [PATCH 15/17] style and typos --- docs/source/advanced.rst | 9 +-------- docs/source/quickstart.rst | 14 ++++++++------ 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index 934a3440..e9f318d1 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -134,15 +134,8 @@ Loss selection If you do not explicitly define a loss, SciKeras attempts to find a loss that matches the type of target (see :py:func:`sklearn.utils.multiclass.type_of_target`). -To override the default loss, simply specify a loss function: - -.. code-block:: diff - - -KerasClassifier(model=model_build_fn) - +KerasClassifier(model=model_build_fn, loss="categorical_crossentropy") - For guidance selecting losses in Keras, please see Jason Brownlee's -excelent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_ +excellent article `How to Choose Loss Functions When Training Deep Learning Neural Networks`_ as well as `Keras Losses docs`_. 
Default losses are selected as follows: diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 4621440f..3788582a 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -38,17 +38,19 @@ it on a toy classification dataset using SciKeras model.add(keras.layers.Activation("softmax")) return model - clf = KerasClassifier( - get_model, - hidden_layer_dim=100, - ) + clf = KerasClassifier(get_model, hidden_layer_dim=100) clf.fit(X, y) y_proba = clf.predict_proba(X) -Note that SciKeras even chooses a loss function and compiles your model -(see the :ref:`Advanced Usage` section for more details)! +Note that SciKeras even chooses a loss function and compiles your model. +To override the default loss, simply specify a loss function: + +.. code-block:: diff + + -KerasClassifier(get_model, hidden_layer_dim=100) + +KerasClassifier(get_model, loss="categorical_crossentropy") In an sklearn Pipeline ---------------------- From 1c2b9c38841f10bb6c84280be8426ec5d4fab933 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 31 Mar 2021 08:43:18 -0500 Subject: [PATCH 16/17] Add detail --- docs/source/advanced.rst | 2 ++ docs/source/quickstart.rst | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index e9f318d1..bf36c63f 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -128,6 +128,8 @@ offer an easy way to compile and tune compilation parameters. Examples: In all cases, returning an un-compiled model is equivalent to calling ``model.compile(**compile_kwargs)`` within ``model_build_fn``. +.. 
_loss-selection: + Loss selection ++++++++++++++ diff --git a/docs/source/quickstart.rst b/docs/source/quickstart.rst index 3788582a..66ba19b1 100644 --- a/docs/source/quickstart.rst +++ b/docs/source/quickstart.rst @@ -52,6 +52,11 @@ To override the default loss, simply specify a loss function: -KerasClassifier(get_model, hidden_layer_dim=100) +KerasClassifier(get_model, loss="categorical_crossentropy") +In this case, you would need to specify the loss since SciKeras +will not default to categorical crossentropy, even for one-hot +encoded targets. +See :ref:`loss-selection` for more details. + In an sklearn Pipeline ---------------------- From 671d5ba2deab74703209d295fc9fbcc188792653 Mon Sep 17 00:00:00 2001 From: Adrian Garcia Badaracco <1755071+adriangb@users.noreply.github.com> Date: Wed, 31 Mar 2021 14:42:54 -0500 Subject: [PATCH 17/17] Add note on cce --- docs/source/advanced.rst | 3 +++ pyproject.toml | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/source/advanced.rst b/docs/source/advanced.rst index bf36c63f..624b6357 100644 --- a/docs/source/advanced.rst +++ b/docs/source/advanced.rst @@ -157,6 +157,9 @@ Classification | > 1 | -- | -- | unsupported | +-----------+-----------+----------+---------------------------------+ +Note that SciKeras will not automatically infer the loss for one-hot encoded targets; +you would need to explicitly specify ``loss="categorical_crossentropy"``. + Regression .......... diff --git a/pyproject.toml b/pyproject.toml index 4245a16c..01e5f126 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,7 +30,7 @@ version = "0.2.1" [tool.poetry.dependencies] importlib-metadata = {version = "^3.4.0", python = "<3.8"} -python = ">=3.6.7, <3.9" +python = ">=3.6.7,<3.9" scikit-learn = "^0.22.0" tensorflow = "^2.4.0"