Skip to content

Latest commit

 

History

History
168 lines (139 loc) · 6.28 KB

implementations.md

File metadata and controls

168 lines (139 loc) · 6.28 KB

Key implementations in official code

From the paper PGGAN

Progressive training

  • All existing layers in both networks remain trainable throughout the training process.
  • When new layers are added to the networks, we fade them in smoothly.
  • The fade-in weight alpha is inverted in the official code relative to the paper:
  • implementation_alpha = (1.0 - paper_alpha)

Minibatch standard deviations

  • Add a minibatch standard deviation layer towards the end of the discriminator.
def minibatch_stddev_layer(x, group_size=4, num_new_features=1):
    """Append minibatch standard-deviation statistics to `x` as extra maps.

    The batch is split into groups of at most `group_size` samples. Inside
    each group the standard deviation across the group is computed for every
    position, averaged down to `num_new_features` scalars per group, and
    tiled back over the group and over axes 2 and 3 of `x`.

    Args:
        x: rank-4 tensor. The new maps are concatenated on axis 1, so the
            layout is presumably NCHW -- TODO confirm against callers.
        group_size: upper bound on samples per group; clamped to the batch
            size. The reshape needs the batch size to be a multiple of the
            resulting group size.
        num_new_features: number of statistic maps appended; expected to
            divide `x.shape[1]` (floor division is used).

    Returns:
        `x` with `num_new_features` additional maps concatenated on axis 1.
    """
    with tf.variable_scope('MinibatchStddev'):
        # A group can never be larger than the batch itself.
        samples_per_group = tf.minimum(group_size, tf.shape(x)[0])
        in_shape = x.shape
        channels_per_feature = in_shape[1] // num_new_features

        # [group, n_groups, new_features, channels_per_feature, dim2, dim3]
        grouped = tf.reshape(
            x,
            [samples_per_group, -1, num_new_features, channels_per_feature, in_shape[2], in_shape[3]],
        )
        # Statistics are computed in float32 regardless of the input dtype.
        grouped = tf.cast(grouped, tf.float32)

        # Standard deviation over the group axis (axis 0, which is dropped).
        centered = grouped - tf.reduce_mean(grouped, axis=0, keepdims=True)
        variance = tf.reduce_mean(tf.square(centered), axis=0)
        stddev = tf.sqrt(variance + 1e-8)

        # Average over channels-per-feature and both trailing axes, then drop
        # the (now size-1) channels-per-feature axis.
        pooled = tf.reduce_mean(stddev, axis=[2, 3, 4], keepdims=True)
        pooled = tf.reduce_mean(pooled, axis=[2])
        pooled = tf.cast(pooled, x.dtype)

        # Replicate each group's statistic over its samples and over the two
        # trailing axes so it can be concatenated with `x`.
        stat_maps = tf.tile(pooled, [samples_per_group, 1, in_shape[2], in_shape[3]])
        return tf.concat([x, stat_maps], axis=1)

Equalized learning rate

  • Initialize weights from N(0, 1) and scale them at runtime by the per-layer normalization constant from He’s initializer.
  • w^_i = w_i * c (the code below multiplies the stored weight by the constant at runtime)
  • c = sqrt(2.0/number_of_inputs)
  • Reduced learning rate for fully connected layers in mapping network: lambda^ = 0.01 * lambda
def get_weight(weight_shape, gain, lrmul):
    """Create the 'weight' variable with equalized-learning-rate scaling.

    The variable is stored with standard deviation `1 / lrmul` and the
    returned tensor is that variable multiplied by `he_std * lrmul`, where
    `he_std = gain / sqrt(fan_in)`. The He scaling is therefore applied at
    runtime rather than baked into the initial values.

    Args:
        weight_shape: shape of the variable; every axis except the last one
            counts towards fan-in ([kernel, kernel, fmaps_in, fmaps_out] or
            [in, out]).
        gain: numerator of the He constant.
        lrmul: learning-rate multiplier; divides the init std and multiplies
            the runtime coefficient.

    Returns:
        The scaled weight tensor (variable times runtime coefficient).
    """
    # Fan-in is the product of all dimensions except the output one.
    num_inputs = np.prod(weight_shape[:-1])
    he_scale = gain / np.sqrt(num_inputs)

    stored = tf.get_variable(
        'weight',
        shape=weight_shape,
        dtype=tf.float32,
        initializer=tf.initializers.random_normal(0, 1.0 / lrmul),
    )
    # Runtime coefficient: He constant combined with the lr multiplier.
    return stored * (he_scale * lrmul)

From the paper StyleGAN

Generator architecture

  • Style-based Generator

AdaIN

  • Learned affine transformations then specialize w to styles y = (y_s, y_b) that control adaptive instance normalization (AdaIN) operations after each convolution layer of the synthesis network g.
def instance_norm(x, epsilon=1e-8):
    """Normalize `x` to zero mean and unit mean-square over axes 2 and 3.

    Args:
        x: rank-4 tensor (asserted); the original notes describe it as NCHW,
            e.g. [?, 512, h, w].
        epsilon: small constant added to the mean square before the
            reciprocal square root.

    Returns:
        Tensor of the same shape as `x`, normalized independently for every
        (axis-0, axis-1) slice.
    """
    assert len(x.shape) == 4  # NCHW
    with tf.variable_scope('InstanceNorm'):
        eps = tf.constant(epsilon, dtype=tf.float32, name='epsilon')
        spatial_axes = [2, 3]

        # Remove the per-map mean: statistics are [?, C, 1, 1].
        centered = x - tf.reduce_mean(x, axis=spatial_axes, keepdims=True)

        # Divide by the per-map root-mean-square of the centered values.
        mean_square = tf.reduce_mean(tf.square(centered), axis=spatial_axes, keepdims=True)
        inv_std = tf.rsqrt(mean_square + eps)
        normalized = centered * inv_std
    return normalized


def style_mod(x, w):
    """Apply a learned per-channel scale and shift derived from `w` to `x`.

    `w` is projected by `equalized_dense` + `apply_bias` to twice the number
    of channels of `x`. The first half is used as (scale - 1), the second
    half as the shift.

    Args:
        x: feature tensor with channels on axis 1, e.g. [?, 512, h, w].
        w: latent tensor, e.g. [?, 512] -- shape taken from the original
            notes; TODO confirm against callers.

    Returns:
        `x * (scale + 1) + shift`, same shape as `x`.
    """
    with tf.variable_scope('StyleMod'):
        n_channels = x.shape[1]
        # One scale and one shift value per channel.
        units = n_channels * 2

        style = equalized_dense(w, units, gain=1.0, lrmul=1.0)
        style = apply_bias(style, lrmul=1.0)

        # [?, 2, C, 1, ...]: trailing singleton axes broadcast over the
        # remaining dimensions of x.
        trailing_ones = [1] * (len(x.shape) - 2)
        style = tf.reshape(style, [-1, 2, n_channels] + trailing_ones)

        # The +1 makes a zero style output act as the identity scale.
        scale = style[:, 0] + 1
        scaled = x * scale
        shift = style[:, 1]
        x = scaled + shift
    return x


def adaptive_instance_norm(x, w):
    """Instance-normalize `x`, then modulate it with the style derived from `w`.

    This is `style_mod(instance_norm(x), w)`: normalization first, then the
    learned per-channel scale and shift.
    """
    return style_mod(instance_norm(x), w)

Style mixing regularizations

  • To further encourage the styles to localize, we employ mixing regularization, where a given percentage of images are generated using two random latent codes instead of one during training.
def style_mixing_regularization(z, w_dim, w_broadcasted, n_mapping, n_broadcast,
                                train_res_block_idx, style_mixing_prob):
    """Randomly replace the styles of later layers with those of a second latent.

    A second latent `z2` with the same shape as `z` is sampled and mapped
    through `g_mapping`. A cutoff layer index is then chosen, and every layer
    whose index is >= the cutoff takes its style from the second latent
    while layers below the cutoff keep the original `w_broadcasted`.

    Args:
        z: input latent; only its shape is used here.
        w_dim, n_mapping, n_broadcast: forwarded to `g_mapping`;
            `n_broadcast` also sets the number of layer indices.
        w_broadcasted: per-layer styles for `z`; presumably
            [batch, n_broadcast, w_dim] -- TODO confirm against g_mapping.
        train_res_block_idx: index of the resolution block currently being
            trained; `(idx + 1) * 2` is used as the last layer index.
        style_mixing_prob: probability of drawing a random cutoff.

    Returns:
        Tensor shaped like `w_broadcasted` with the mixed styles.
    """
    with tf.name_scope('StyleMix'):
        # Second, independent latent and its mapped per-layer styles.
        z2 = tf.random_normal(tf.shape(z), dtype=tf.float32)
        w_broadcasted2 = g_mapping(z2, w_dim, n_mapping, n_broadcast)
        # Layer indices shaped [1, n_broadcast, 1] so the comparison below
        # can be broadcast to the shape of w_broadcasted.
        layer_indices = np.arange(n_broadcast)[np.newaxis, :, np.newaxis]
        last_layer_index = (train_res_block_idx + 1) * 2
        # With probability style_mixing_prob draw an integer cutoff with
        # minval 1 and maxval last_layer_index; otherwise use
        # last_layer_index itself as the cutoff.
        mixing_cutoff = tf.cond(
            tf.random_uniform([], 0.0, 1.0) < style_mixing_prob,
            lambda: tf.random_uniform([], 1, last_layer_index, dtype=tf.int32),
            lambda: tf.constant(last_layer_index, dtype=tf.int32))
        # Layers below the cutoff keep the first latent's style; the rest
        # take the second latent's style.
        w_broadcasted = tf.where(tf.broadcast_to(layer_indices < mixing_cutoff, tf.shape(w_broadcasted)),
                                 w_broadcasted,
                                 w_broadcasted2)
    return w_broadcasted

Truncation trick in w

  • drawing latent vectors from a truncated or otherwise shrunk sampling space tends to improve average image quality, although some amount of variation is lost.
def truncation_trick(n_broadcast, w_broadcasted, w_avg, truncation_psi, truncation_cutoff):
    """Interpolate per-layer styles towards `w_avg` for the first layers.

    Layers with index < `truncation_cutoff` get interpolation coefficient
    `truncation_psi`; all other layers get coefficient 1.

    Args:
        n_broadcast: number of layers (length of the layer-index axis).
        w_broadcasted: per-layer styles to truncate.
        w_avg: average style used as the interpolation start point.
        truncation_psi: coefficient applied to the truncated layers.
        truncation_cutoff: first layer index that is NOT truncated.

    Returns:
        `lerp(w_avg, w_broadcasted, coefs)`; `lerp` is defined elsewhere --
        presumably a + (b - a) * t, TODO confirm.
    """
    with tf.variable_scope('Truncation'):
        # [1, n_broadcast, 1] so the coefficients broadcast per layer.
        layer_ids = np.arange(n_broadcast)[np.newaxis, :, np.newaxis]
        identity_coefs = np.ones(layer_ids.shape, dtype=np.float32)

        is_truncated = layer_ids < truncation_cutoff
        truncated_coefs = truncation_psi * identity_coefs
        blend = tf.where(is_truncated, truncated_coefs, identity_coefs)

        w_broadcasted = lerp(w_avg, w_broadcasted, blend)
    return w_broadcasted

Losses

  • logistic nonsaturating with gradient penalty
def compute_loss(real_images, real_scores, fake_scores):
    """Compute non-saturating logistic GAN losses with an R1 gradient penalty.

    Args:
        real_images: batch of real images fed to the discriminator; the
            penalty sums squared gradients over axes 1, 2 and 3, so a rank-4
            tensor is expected.
        real_scores: discriminator outputs for `real_images` (must depend on
            `real_images` in the graph so gradients can be taken).
        fake_scores: discriminator outputs for generated images.

    Returns:
        Tuple `(d_loss, g_loss, d_loss_gan_mean, r1_penalty_mean)`:
        the total discriminator loss, the generator loss, the mean logistic
        part of the discriminator loss, and the mean R1 penalty. All four
        are reduced with `tf.reduce_mean`.
    """
    # Weight of the R1 penalty; it enters the loss as r1_gamma / 2.
    r1_gamma = 10.0

    # discriminator loss: logistic loss plus R1 gradient penalty on reals
    d_loss_gan = tf.nn.softplus(fake_scores) + tf.nn.softplus(-real_scores)
    # Summing the scores gives one scalar whose gradient w.r.t. the images is
    # taken in a single tf.gradients call.
    real_loss = tf.reduce_sum(real_scores)
    real_grads = tf.gradients(real_loss, [real_images])[0]
    r1_penalty = tf.reduce_sum(tf.square(real_grads), axis=[1, 2, 3])
    # NOTE(review): if the scores are [N, 1] while r1_penalty is [N], this sum
    # broadcasts to [N, N]; the mean below is unaffected -- confirm score shape.
    d_loss = d_loss_gan + r1_penalty * (r1_gamma * 0.5)
    d_loss = tf.reduce_mean(d_loss)

    # generator loss: logistic nonsaturating
    g_loss = tf.nn.softplus(-fake_scores)
    g_loss = tf.reduce_mean(g_loss)
    return d_loss, g_loss, tf.reduce_mean(d_loss_gan), tf.reduce_mean(r1_penalty)