Key implementations in official code
All existing layers in both networks remain trainable throughout the training process.
When new layers are added to the networks, we fade them in smoothly.
Note: alpha is flipped in the official code relative to the paper:
implementation_alpha = (1.0 - paper_alpha)
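As a minimal sketch of how the flipped alpha can drive the fade-in (fade_in_generator_output, torgb, upscale2d and new_block below are illustrative names, not the official wiring):

# Fade-in sketch; torgb / upscale2d / new_block stand in for the repo's helpers.
def fade_in_generator_output(x_lowres, implementation_alpha):
    skip = upscale2d(torgb(x_lowres))   # old path: previous resolution, upsampled
    new = torgb(new_block(x_lowres))    # new path: freshly grown higher-resolution block
    # implementation_alpha: 1.0 -> only the old path, 0.0 -> only the new path
    return implementation_alpha * skip + (1.0 - implementation_alpha) * new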
Minibatch standard deviations
A minibatch standard deviation layer is added towards the end of the discriminator.
# Imports used by the snippets below (TensorFlow 1.x API).
import numpy as np
import tensorflow as tf


def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
    with tf.variable_scope('MinibatchStddev'):
        # Minibatch must be divisible by (or smaller than) group_size.
        group_size = tf.minimum(group_size, tf.shape(x)[0])
        s = x.shape                                             # [NCHW] input shape
        # [GMncHW] split minibatch into M groups of size G, channels into n groups of size c
        y = tf.reshape(x, [group_size, -1, num_new_features, s[1] // num_new_features, s[2], s[3]])
        y = tf.cast(y, tf.float32)                              # [GMncHW] cast to fp32
        y -= tf.reduce_mean(y, axis=0, keepdims=True)           # [GMncHW] subtract mean over group
        y = tf.reduce_mean(tf.square(y), axis=0)                # [MncHW]  variance over group
        y = tf.sqrt(y + 1e-8)                                   # [MncHW]  stddev over group
        y = tf.reduce_mean(y, axis=[2, 3, 4], keepdims=True)    # [Mn111]  average over fmaps and pixels
        y = tf.reduce_mean(y, axis=[2])                         # [Mn11]   drop the singleton fmap axis
        y = tf.cast(y, x.dtype)                                 # cast back to original dtype
        y = tf.tile(y, [group_size, 1, s[2], s[3]])             # [NnHW]   replicate over group and pixels
        return tf.concat([x, y], axis=1)                        # [NCHW]   append as new feature map(s)
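As a usage note, this layer sits in the discriminator's final (lowest-resolution) block, just before the last convolution and dense layers. The sketch below uses plain tf.layers instead of the repo's equalized-learning-rate helpers, so the exact calls and layer widths are assumptions:

# Illustrative final 4x4 discriminator block using plain tf.layers (not the official helpers).
def discriminator_final_block(x):                          # x: [N, 512, 4, 4]
    x = minibatch_stddev_layer(x, group_size=4)            # -> [N, 513, 4, 4]
    x = tf.nn.leaky_relu(tf.layers.conv2d(x, 512, 3, padding='same', data_format='channels_first'))
    x = tf.layers.flatten(x)
    x = tf.nn.leaky_relu(tf.layers.dense(x, 512))
    return tf.layers.dense(x, 1)                           # single realism score per image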
Weights are initialized via N(0, 1) and then scaled at runtime by the per-layer normalization constant from He's initializer (this matches runtime_coef in get_weight below):
w^_i = w_i * c
c = sqrt(2.0 / number_of_inputs)
The learning rate of the fully connected layers in the mapping network is reduced by two orders of magnitude: lambda^ = 0.01 * lambda (implemented via the lrmul argument below).
def get_weight(weight_shape, gain, lrmul):
    fan_in = np.prod(weight_shape[:-1])   # [kernel, kernel, fmaps_in, fmaps_out] or [in, out]
    he_std = gain / np.sqrt(fan_in)       # He init

    # equalized learning rate
    init_std = 1.0 / lrmul
    runtime_coef = he_std * lrmul

    # create variable
    weight = tf.get_variable('weight', shape=weight_shape, dtype=tf.float32,
                             initializer=tf.initializers.random_normal(0, init_std)) * runtime_coef
    return weight
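The style_mod snippet further down calls equalized_dense and apply_bias, which are not reproduced here. A minimal sketch of how they could be built on top of get_weight follows; the exact signatures are assumptions inferred from the call sites:

# Assumed helpers; names and arguments follow the calls in style_mod below.
def equalized_dense(x, units, gain, lrmul):
    w = get_weight([x.shape[1].value, units], gain=gain, lrmul=lrmul)
    return tf.matmul(x, w)


def apply_bias(x, lrmul):
    b = tf.get_variable('bias', shape=[x.shape[1]], dtype=tf.float32,
                        initializer=tf.initializers.zeros()) * lrmul
    if len(x.shape) == 2:
        return x + b
    return x + tf.reshape(b, [1, -1, 1, 1])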
Style-based Generator
Learned affine transformations then specialize w to styles y = (y_s, y_b)
that control adaptive instance normalization (AdaIN) operations after each convolution layer of the synthesis network g.
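Concretely, each feature map x_i is normalized separately and then scaled and biased using the corresponding scalar components of style y:
AdaIN(x_i, y) = y_{s,i} * (x_i - mu(x_i)) / sigma(x_i) + y_{b,i}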
def instance_norm(x, epsilon=1e-8):
    # x: [?, 512, h, w]
    assert len(x.shape) == 4  # NCHW
    with tf.variable_scope('InstanceNorm'):
        epsilon = tf.constant(epsilon, dtype=tf.float32, name='epsilon')

        # subtract the per-channel spatial mean; tf.reduce_mean(...): [?, 512, 1, 1]
        x = x - tf.reduce_mean(x, axis=[2, 3], keepdims=True)

        # divide by the per-channel spatial stddev; x stays [?, 512, h, w]
        x = x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=[2, 3], keepdims=True) + epsilon)
        return x
def style_mod(x, w):
    # x: [?, 512, h, w]
    # w: [?, 512]
    with tf.variable_scope('StyleMod'):
        # one scale and one bias per feature map, e.g. units = 1024 for 512 channels
        units = x.shape[1] * 2

        # style: [?, 1024]
        style = equalized_dense(w, units, gain=1.0, lrmul=1.0)
        style = apply_bias(style, lrmul=1.0)

        # style: [?, 2, 512, 1, 1]
        style = tf.reshape(style, [-1, 2, x.shape[1]] + [1] * (len(x.shape) - 2))

        # y_s = style[:, 0] + 1 (scale), y_b = style[:, 1] (bias); x stays [?, 512, h, w]
        x = x * (style[:, 0] + 1) + style[:, 1]
        return x
def adaptive_instance_norm(x, w):
    x = instance_norm(x)
    x = style_mod(x, w)
    return x
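A rough sketch of how adaptive_instance_norm slots into a single synthesis-network layer (conv -> noise -> activation -> AdaIN). The layer below uses plain tf.layers plus an inline noise input, so treat the exact calls and widths as assumptions rather than the official building block:

# Illustrative synthesis layer; w_layer is the (broadcasted) w vector for this layer.
def synthesis_layer(x, w_layer, name):                 # x: [N, 512, h, w], w_layer: [N, 512]
    with tf.variable_scope(name):
        x = tf.layers.conv2d(x, 512, 3, padding='same', data_format='channels_first')
        # per-pixel noise, scaled by a learned per-channel strength (StyleGAN's noise inputs)
        noise = tf.random_normal([tf.shape(x)[0], 1, tf.shape(x)[2], tf.shape(x)[3]])
        noise_strength = tf.get_variable('noise_strength', shape=[1, 512, 1, 1],
                                         initializer=tf.initializers.zeros())
        x = tf.nn.leaky_relu(x + noise * noise_strength)
        x = adaptive_instance_norm(x, w_layer)         # style modulation via AdaIN
        return x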
Style mixing regularization
To further encourage the styles to localize, we employ mixing regularization, where a given percentage of images are generated using two random latent codes instead of one during training.
def style_mixing_regularization(z, w_dim, w_broadcasted, n_mapping, n_broadcast,
                                train_res_block_idx, style_mixing_prob):
    with tf.name_scope('StyleMix'):
        # second latent code and its broadcasted w
        z2 = tf.random_normal(tf.shape(z), dtype=tf.float32)
        w_broadcasted2 = g_mapping(z2, w_dim, n_mapping, n_broadcast)

        # layer_indices: [1, n_broadcast, 1]
        layer_indices = np.arange(n_broadcast)[np.newaxis, :, np.newaxis]
        last_layer_index = (train_res_block_idx + 1) * 2

        # with probability style_mixing_prob, pick a random crossover point;
        # otherwise keep every layer's w from the first latent code
        mixing_cutoff = tf.cond(
            tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
            lambda: tf.random_uniform([], 1, last_layer_index, dtype=tf.int32),
            lambda: tf.constant(last_layer_index, dtype=tf.int32))

        # layers before the cutoff use w from z, layers after it use w from z2
        w_broadcasted = tf.where(
            tf.broadcast_to(layer_indices < mixing_cutoff, tf.shape(w_broadcasted)),
            w_broadcasted,
            w_broadcasted2)
    return w_broadcasted
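For reference, n_broadcast follows the synthesis depth (two w vectors per resolution block, e.g. 18 for a 1024x1024 network), and the paper reports mixing for 90% of training images. A hedged example call; the concrete values are assumptions for illustration:

# Example call during training; values assume the paper's 1024x1024 configuration
# (9 resolution blocks -> 18 broadcast layers).
w_broadcasted = style_mixing_regularization(
    z, w_dim=512, w_broadcasted=w_broadcasted, n_mapping=8, n_broadcast=18,
    train_res_block_idx=8, style_mixing_prob=0.9)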
Truncation trick in W
Drawing latent vectors from a truncated or otherwise shrunk sampling space tends to improve average image quality, although some amount of variation is lost.
def truncation_trick(n_broadcast, w_broadcasted, w_avg, truncation_psi, truncation_cutoff):
    with tf.variable_scope('Truncation'):
        # layer_indices: [1, n_broadcast, 1]
        layer_indices = np.arange(n_broadcast)[np.newaxis, :, np.newaxis]
        ones = np.ones(layer_indices.shape, dtype=np.float32)

        # interpolate towards the average w, but only for the first truncation_cutoff layers
        coefs = tf.where(layer_indices < truncation_cutoff, truncation_psi * ones, ones)
        w_broadcasted = lerp(w_avg, w_broadcasted, coefs)   # lerp(a, b, t) = a + (b - a) * t
    return w_broadcasted
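The paper samples with psi = 0.7 and truncates only the coarse layers. Assuming the reported defaults carry over (truncation applied to the first 8 layers), a call at sampling time could look like this, where w_avg is the running average of w tracked during training:

# Example call at sampling time; psi / cutoff values follow the paper's reported
# defaults and are assumptions for this snippet.
w_broadcasted = truncation_trick(n_broadcast=18, w_broadcasted=w_broadcasted,
                                 w_avg=w_avg, truncation_psi=0.7, truncation_cutoff=8)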
Logistic non-saturating loss with R1 gradient penalty
The discriminator loss adds an R1 penalty on real images, R1 = (gamma / 2) * E[ ||grad_x D(x)||^2 ], with gamma = 10.
def compute_loss(real_images, real_scores, fake_scores):
    r1_gamma, r2_gamma = 10.0, 0.0

    # discriminator loss: non-saturating logistic loss plus R1 gradient penalty on real images
    d_loss_gan = tf.nn.softplus(fake_scores) + tf.nn.softplus(-real_scores)
    real_loss = tf.reduce_sum(real_scores)
    real_grads = tf.gradients(real_loss, [real_images])[0]
    r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])
    # r1_penalty = tf.reduce_mean(r1_penalty)
    d_loss = d_loss_gan + r1_penalty * (r1_gamma * 0.5)
    d_loss = tf.reduce_mean(d_loss)

    # generator loss: logistic non-saturating
    g_loss = tf.nn.softplus(-fake_scores)
    g_loss = tf.reduce_mean(g_loss)
    return d_loss, g_loss, tf.reduce_mean(d_loss_gan), tf.reduce_mean(r1_penalty)
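To close the loop, a minimal sketch of attaching the two losses to separate Adam optimizers; the variable-scope names, learning rate and beta values below are assumptions, not necessarily the official training schedule:

# Hypothetical optimizer wiring; hyperparameters and scope names are placeholders.
d_loss, g_loss, d_loss_gan, r1_penalty = compute_loss(real_images, real_scores, fake_scores)

d_optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.0, beta2=0.99, epsilon=1e-8)
g_optimizer = tf.train.AdamOptimizer(learning_rate=0.001, beta1=0.0, beta2=0.99, epsilon=1e-8)

d_vars = tf.trainable_variables(scope='discriminator')
g_vars = tf.trainable_variables(scope='generator')

d_train_op = d_optimizer.minimize(d_loss, var_list=d_vars)
g_train_op = g_optimizer.minimize(g_loss, var_list=g_vars)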