Skip to content

Commit

Permalink
Unify BN behavior across backends (fix)
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 3, 2016
1 parent 97d2a73 commit 0588393
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 16 deletions.
18 changes: 9 additions & 9 deletions keras/backend/tensorflow_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -622,10 +622,10 @@ def normalize_batch_in_training(x, gamma, beta,
reduction_axes, epsilon=0.0001):
'''Compute mean and std for batch then apply batch_normalization on batch.
'''
mean, std = tf.nn.moments(x, reduction_axes,
mean, var = tf.nn.moments(x, reduction_axes,
shift=None, name=None, keep_dims=False)
if sorted(reduction_axes) == range(ndim(x))[:-1]:
normed = tf.nn.batch_normalization(x, mean, std,
normed = tf.nn.batch_normalization(x, mean, var,
beta, gamma,
epsilon)
else:
Expand All @@ -639,21 +639,21 @@ def normalize_batch_in_training(x, gamma, beta,
target_shape = tf.pack(target_shape)

broadcast_mean = tf.reshape(mean, target_shape)
broadcast_std = tf.reshape(std, target_shape)
broadcast_var = tf.reshape(var, target_shape)
broadcast_gamma = tf.reshape(gamma, target_shape)
broadcast_beta = tf.reshape(beta, target_shape)
normed = tf.nn.batch_normalization(x, broadcast_mean, broadcast_std,
normed = tf.nn.batch_normalization(x, broadcast_mean, broadcast_var,
broadcast_beta, broadcast_gamma,
epsilon)
return normed, mean, std
return normed, mean, var


def batch_normalization(x, mean, var, beta, gamma, epsilon=0.0001):
    '''Apply batch normalization on x given mean, var, beta and gamma:

    output = (x - mean) / sqrt(var + epsilon) * gamma + beta

    # Arguments
        x: input tensor.
        mean: mean of the batch (per reduced axis).
        var: variance of the batch (per reduced axis).
        beta: shift tensor.
        gamma: scale tensor.
        epsilon: fuzz factor added to the variance *inside* the square
            root — this is what tf.nn.batch_normalization computes
            (it multiplies by rsqrt(var + epsilon)), so the docstring
            formula above reflects the actual behavior.

    # Returns
        The normalized tensor.
    '''
    return tf.nn.batch_normalization(x, mean, var, beta, gamma, epsilon)


# SHAPE OPERATIONS
Expand Down
14 changes: 7 additions & 7 deletions keras/backend/theano_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,7 @@ def normalize_batch_in_training(x, gamma, beta,
reduction_axes, epsilon=0.0001):
'''Compute mean and std for batch then apply batch_normalization on batch.
'''
std = T.sqrt(x.var(reduction_axes) + epsilon)
var = x.var(reduction_axes)
mean = x.mean(reduction_axes)

target_shape = []
Expand All @@ -374,20 +374,20 @@ def normalize_batch_in_training(x, gamma, beta,
target_shape = T.stack(*target_shape)

broadcast_mean = T.reshape(mean, target_shape)
broadcast_std = T.reshape(std, target_shape)
broadcast_var = T.reshape(var, target_shape)
broadcast_beta = T.reshape(beta, target_shape)
broadcast_gamma = T.reshape(gamma, target_shape)
normed = batch_normalization(x, broadcast_mean, broadcast_std,
normed = batch_normalization(x, broadcast_mean, broadcast_var,
broadcast_beta, broadcast_gamma,
epsilon)
return normed, mean, std
return normed, mean, var


def batch_normalization(x, mean, var, beta, gamma, epsilon=0.0001):
    '''Apply batch normalization on x given mean, var, beta and gamma.

    output = (x - mean) / sqrt(var + epsilon) * gamma + beta

    # Arguments
        x: input tensor.
        mean: mean of the batch (per reduced axis).
        var: variance of the batch (per reduced axis).
        beta: shift tensor.
        gamma: scale tensor.
        epsilon: fuzz factor added to the variance inside the square
            root, matching the TensorFlow backend
            (tf.nn.batch_normalization uses rsqrt(var + epsilon)).

    # Returns
        The normalized tensor.
    '''
    # NOTE(review): was `sqrt(var) + epsilon`, which adds epsilon
    # *outside* the sqrt and therefore diverges numerically from the
    # TensorFlow backend. Moved epsilon inside the sqrt so both
    # backends compute the same denominator.
    normed = T.nnet.bn.batch_normalization(x, gamma, beta, mean,
                                           sqrt(var + epsilon),
                                           mode='high_mem')
    return normed

Expand Down

0 comments on commit 0588393

Please sign in to comment.