diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index bcb8be0affa..2d23706948d 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -1905,17 +1905,17 @@ def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3): # The mean / var / beta / gamma may be processed by broadcast # so it may have extra axes with 1, it is not needed and should be removed if ndim(mean) > 1: - mean = tf.reshape(mean, (-1)) + mean = tf.reshape(mean, [-1]) if ndim(var) > 1: - var = tf.reshape(var, (-1)) + var = tf.reshape(var, [-1]) if beta is None: beta = zeros_like(mean) elif ndim(beta) > 1: - beta = tf.reshape(beta, (-1)) + beta = tf.reshape(beta, [-1]) if gamma is None: gamma = ones_like(mean) elif ndim(gamma) > 1: - gamma = tf.reshape(gamma, (-1)) + gamma = tf.reshape(gamma, [-1]) y, _, _ = tf.nn.fused_batch_norm( x, gamma,