Skip to content

Commit

Permalink
Enable us to use sklearn to do cv for functional api (#9320)
Browse files Browse the repository at this point in the history
* enable us to use sklearn to do cv for functional api

* adjust code for multiclass

* remove tailing space

* remove irrelevant comment
  • Loading branch information
XinsongDu authored and fchollet committed Feb 8, 2018
1 parent e6c3f77 commit d9f26a9
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions keras/wrappers/scikit_learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,12 @@ def predict(self, x, **kwargs):
Class predictions.
"""
kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
classes = self.model.predict_classes(x, **kwargs)

proba = self.model.predict(x, **kwargs)
if proba.shape[-1] > 1:
classes = proba.argmax(axis=-1)
else:
classes = (proba > 0.5).astype('int32')
return self.classes_[classes]

def predict_proba(self, x, **kwargs):
Expand All @@ -247,7 +252,7 @@ def predict_proba(self, x, **kwargs):
(instead of `(n_sample, 1)` as in Keras).
"""
kwargs = self.filter_sk_params(Sequential.predict_proba, kwargs)
probs = self.model.predict_proba(x, **kwargs)
probs = self.model.predict(x, **kwargs)

# check if binary classification
if probs.shape[1] == 1:
Expand Down

0 comments on commit d9f26a9

Please sign in to comment.