ENH: Check input dimensions against the initialized model_ #143

Open · wants to merge 21 commits into base: master
12 changes: 7 additions & 5 deletions scikeras/_utils.py
@@ -4,7 +4,7 @@
 import warnings
 
 from inspect import isclass
-from typing import Any, Callable, Dict, Iterable, List, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Union
 
 import numpy as np
 import tensorflow as tf
@@ -107,16 +107,18 @@ def make_model_picklable(model_obj):
 
 
 def _windows_upcast_ints(
-    arr: Union[List[np.ndarray], np.ndarray]
+    inp: Union[List[np.ndarray], Mapping[Any, np.ndarray], np.ndarray]
 ) -> Union[List[np.ndarray], np.ndarray]:
     # see tensorflow/probability#886
     def _upcast(x):
         return x.astype("int64") if x.dtype == np.int32 else x
 
-    if isinstance(arr, np.ndarray):
-        return _upcast(arr)
+    if isinstance(inp, np.ndarray):
+        return _upcast(inp)
+    elif isinstance(inp, Mapping):
+        return {k: _upcast(x_) for k, x_ in inp.items()}
     else:
-        return [_upcast(x_) for x_ in arr]
+        return [_upcast(x_) for x_ in inp]
 
 
 def route_params(
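The helper now mirrors the three container types Keras accepts for `X`/`y`: a single array, a list of arrays, or a name-to-array mapping. A small usage sketch of the patched behavior (it imports the private helper by the path used in this PR, so treat it as illustrative rather than stable API):

```python
import numpy as np

from scikeras._utils import _windows_upcast_ints  # private helper patched in this PR

arr = np.array([1, 2, 3], dtype=np.int32)

# Single array: returned upcast to int64
assert _windows_upcast_ints(arr).dtype == np.int64

# List input: each element is upcast
assert all(a.dtype == np.int64 for a in _windows_upcast_ints([arr, arr]))

# New in this PR: Mapping input values are upcast, keys preserved
out = _windows_upcast_ints({"input_1": arr})
assert set(out) == {"input_1"} and out["input_1"].dtype == np.int64
```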
63 changes: 48 additions & 15 deletions scikeras/wrappers.py
@@ -5,7 +5,7 @@
 import warnings
 
 from collections import defaultdict
-from typing import Any, Callable, Dict, Iterable, List, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Tuple, Type, Union
 
 import numpy as np
 import tensorflow as tf
@@ -530,22 +530,45 @@ def _fit_keras_model(
                 raise e
             self.history_[key] += val
 
-    def _check_model_compatibility(self, y: np.ndarray) -> None:
-        """Checks that the model output number and y shape match.
+    def _check_model_outputs(self, y):
+        # output shapes depend on the number of classes in classification,
+        # hence we cannot just check y here, we need the user (or the
+        # data transformer) to tell us what to expect via n_outputs_expected_
+        # n_outputs_expected_ is generated by data transformers
+        # and recovered via target_encoder_.get_meta();
+        # we recognize it but do not force it to be
+        # generated to avoid forcing users to subclass
+        # generic transformers just to make them SciKeras compatible
+        n_out_expect = getattr(self, "n_outputs_expected_", None)
+        if n_out_expect and n_out_expect != len(self.model_.outputs):
+            raise ValueError(
+                "The target ``y`` seems to consist of"
+                f" {n_out_expect} outputs, but this Keras"
+                f" Model has {len(self.model_.outputs)} outputs."
+            )
 
-        This is in place to avoid cryptic TF errors.
-        """
-        # check if this is a multi-output model
-        if hasattr(self, "n_outputs_expected_"):
-            # n_outputs_expected_ is generated by data transformers
-            # we recognize the attribute but do not force it to be
-            # generated
-            if self.n_outputs_expected_ != len(self.model_.outputs):
+    def _check_model_inputs(self, X):
+        if isinstance(X, np.ndarray):
+            # Keras Model's inputs are always a list
+            X = [X]
+        elif isinstance(X, Mapping):
+            X = [X[inp_name] for inp_name in self.model_.input_names]
+        if len(X) != len(self.model_.inputs):
+            raise ValueError(
+                f"``X`` has {len(X)} inputs, but the Keras model"
+                f" has {len(self.model_.inputs)} inputs."
+            )
+        for X_in, model_in in zip(X, self.model_.inputs):
+            # check shape
+            X_in_shape = (1,) if X_in.ndim == 1 else X_in.shape[1:]
+            model_in_shape = model_in.shape[1:]
+            if X_in_shape != model_in_shape:
                 raise ValueError(
-                    "Detected a Keras model input of size"
-                    f" {y[0].shape[0]}, but {self.model_} has"
-                    f" {self.model_.outputs} outputs"
+                    f"Input {model_in.name} expected shape"
+                    f" {model_in_shape} but got {X_in_shape}."
                 )
+
+    def _check_model_loss(self):
         # check that if the user gave us a loss function it ended up in
         # the actual model
         init_params = inspect.signature(self.__init__).parameters
@@ -565,6 +588,16 @@ def _check_model_compatibility(self, y: np.ndarray) -> None:
                         " Data may not match loss function!"
                     )
 
+    def _check_model_compatibility(self, X, y) -> None:
+        """Checks that the model inputs, outputs and loss
+        match the given or expected X, y & loss.
+
+        This is in place to avoid cryptic TF errors.
+        """
+        self._check_model_inputs(X)
+        self._check_model_outputs(y)
+        self._check_model_loss()
+
     def _validate_data(
         self, X=None, y=None, reset: bool = False, y_numeric: bool = False
     ) -> Tuple[np.ndarray, Union[np.ndarray, None]]:
@@ -865,7 +898,7 @@ def _fit(
         y = self.target_encoder_.transform(y)
         X = self.feature_encoder_.transform(X)
 
-        self._check_model_compatibility(y)
+        self._check_model_compatibility(X, y)
 
         self._fit_keras_model(
             X,
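The heart of the new validation is a shape comparison between the data passed to ``fit``/``partial_fit``/``predict`` and the inputs the already-built Keras ``Model`` declares. A minimal standalone sketch of that logic (illustrative names, assuming TF 2.x's `tf.keras`; this is not the scikeras API itself):

```python
import numpy as np
from tensorflow import keras


def check_inputs_against_model(X, model: keras.Model) -> None:
    """Raise ValueError if X does not match the model's declared inputs."""
    # Normalize X to a list, mirroring how Keras treats multi-input models
    if isinstance(X, np.ndarray):
        X = [X]
    elif isinstance(X, dict):
        X = [X[name] for name in model.input_names]
    if len(X) != len(model.inputs):
        raise ValueError(
            f"Got {len(X)} input arrays, but the model has {len(model.inputs)} inputs."
        )
    for X_in, model_in in zip(X, model.inputs):
        # Compare per-sample shapes, ignoring the batch dimension
        X_shape = (1,) if X_in.ndim == 1 else X_in.shape[1:]
        model_shape = tuple(model_in.shape[1:])
        if X_shape != model_shape:
            raise ValueError(
                f"Input {model_in.name} expected shape {model_shape}, got {X_shape}."
            )


# A model built for 2-feature input rejects 3-D data before TF raises a cryptic error
model = keras.Sequential([keras.layers.Dense(1, input_shape=(2,))])
check_inputs_against_model(np.zeros((4, 2)), model)  # passes silently
try:
    check_inputs_against_model(np.zeros((4, 2, 1)), model)
except ValueError as err:
    print(err)  # Input ... expected shape (2,), got (2, 1).
```

Running the check eagerly, before handing the data to `Model.fit` or `Model.predict`, is what lets the wrapper surface a plain ``ValueError`` instead of a deep TensorFlow graph error.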
8 changes: 5 additions & 3 deletions tests/test_errors.py
@@ -22,14 +22,16 @@ def test_X_shape_change():
         loss=KerasRegressor.r_squared,
         hidden_layer_sizes=(100,),
     )
-    X = np.array([[1, 2], [3, 4]]).reshape(2, 2, 1)
+    X = np.array([[1, 2], [3, 4]])
     y = np.array([[0, 1, 0], [1, 0, 0]])
 
     estimator.fit(X=X, y=y)
 
-    # Calling with a different number of dimensions for X raises an error
     with pytest.raises(ValueError, match="dimensions in X"):
-        estimator.partial_fit(X=X.reshape(2, 2), y=y)
+        # Calling with a different number of dimensions for X raises an error
+        estimator.partial_fit(X=X.reshape(2, 2, 1), y=y)
+    with pytest.raises(ValueError, match="dimensions in X"):
+        estimator.predict(X=X.reshape(2, 2, 1))
 
 
 def test_unknown_param():