Skip to content

Commit

Permalink
Resolve keras-team#2960
Browse files Browse the repository at this point in the history
Introduce `K.var` so that the standard deviation computation can
be made numerically stable. Instead of

	K.std(x)

the user is able to write

	K.sqrt(K.var(x) + self.epsilon)

avoiding a division by zero in the gradient computation of `sqrt`.
  • Loading branch information
nemo committed Jun 13, 2016
1 parent 3b83a1b commit e969872
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
16 changes: 11 additions & 5 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,17 +314,23 @@ def prod(x, axis=None, keepdims=False):
return tf.reduce_prod(x, reduction_indices=axis, keep_dims=keepdims)


def std(x, axis=None, keepdims=False):
'''Standard deviation of a tensor, alongside the specificied axis.
def var(x, axis=None, keepdims=False):
'''Variance of a tensor, alongside the specificied axis.
'''
axis = _normalize_axis(axis, ndim(x))
if x.dtype.base_dtype == tf.bool:
x = tf.cast(x, _FLOATX)
m = tf.reduce_mean(x, reduction_indices=axis, keep_dims=True)
devs_squared = tf.square(x - m)
return tf.sqrt(tf.reduce_mean(devs_squared,
reduction_indices=axis,
keep_dims=keepdims))
return tf.reduce_mean(devs_squared,
reduction_indices=axis,
keep_dims=keepdims)


def std(x, axis=None, keepdims=False):
'''Standard deviation of a tensor, alongside the specificied axis.
'''
return tf.sqrt(var(x, axis=axis, keepdims=keepdims))


def mean(x, axis=None, keepdims=False):
Expand Down
4 changes: 4 additions & 0 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ def std(x, axis=None, keepdims=False):
return T.std(x, axis=axis, keepdims=keepdims)


def var(x, axis=None, keepdims=False):
return T.var(x, axis=axis, keepdims=keepdims)


def any(x, axis=None, keepdims=False):
'''Bitwise reduction (logical OR).
'''
Expand Down
2 changes: 1 addition & 1 deletion keras/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def call(self, x, mask=None):
elif self.mode == 1:
# sample-wise normalization
m = K.mean(x, axis=-1, keepdims=True)
std = K.std(x, axis=-1, keepdims=True)
std = K.sqrt(K.var(x, axis=-1, keepdims=True) + self.epsilon)
x_normed = (x - m) / (std + self.epsilon)
out = self.gamma * x_normed + self.beta
return out
Expand Down

0 comments on commit e969872

Please sign in to comment.