Skip to content

Commit

Permalink
Warn if valid data is shuffled (#710)
Browse files Browse the repository at this point in the history
Check for iterator_valid__shuffle=True and issue a UserWarning if found.
  • Loading branch information
BenjaminBossan authored Oct 31, 2020
1 parent d19d98c commit 7ecf4b3
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
10 changes: 10 additions & 0 deletions skorch/net.py
Original file line number Diff line number Diff line change
Expand Up @@ -1516,6 +1516,16 @@ def _check_kwargs(self, kwargs):
did you mean iterator_train__shuffle?
"""
# warn about usage of iterator_valid__shuffle=True, since this
# is almost certainly not what the user wants
if kwargs.get('iterator_valid__shuffle'):
warnings.warn(
"You set iterator_valid__shuffle=True; this is most likely not "
"what you want because the values returned by predict and "
"predict_proba will be shuffled.",
UserWarning)

# check for wrong arguments
unexpected_kwargs = []
missing_dunder_kwargs = []
for key in kwargs:
Expand Down
18 changes: 18 additions & 0 deletions skorch/tests/test_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,24 @@ def __init__(self, *args, **kwargs):
# "optimizer_2".
MyNet(module_cls, optimizer_2__lr=0.123) # should not raise

def test_net_init_with_iterator_valid_shuffle_true(
        self, net_cls, module_cls, recwarn):
    """Initializing with ``iterator_valid__shuffle=True`` emits a warning.

    A user who shuffles the validation iterator might be in for a
    surprise, since predict et al. will then return shuffled
    predictions. Most of the time this is not what users actually
    want, so a UserWarning is expected.
    """
    msg = (
        "You set iterator_valid__shuffle=True; this is most likely not "
        "what you want because the values returned by predict and "
        "predict_proba will be shuffled.")

    # shuffling explicitly disabled: nothing may have been recorded
    net_cls(module_cls, iterator_valid__shuffle=False)
    assert not recwarn.list

    # shuffling enabled: the UserWarning has to show up
    with pytest.warns(UserWarning, match=msg):
        net_cls(module_cls, iterator_valid__shuffle=True)

def test_fit(self, net_fit):
# fitting does not raise anything
pass
Expand Down

0 comments on commit 7ecf4b3

Please sign in to comment.