Unofficial Tensorflow 2 Implementation Of FreeU: Free Lunch in Diffusion U-Net

Paper | Project Page | Video

FreeU is a method that substantially improves diffusion model sample quality at no cost: no training, no additional parameters, and no increase in memory or sampling time.

📖 For more visual results, check out the official project page.

FreeU Code

import numpy as np
import tensorflow as tf


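def Fourier_filter(x, threshold=1, scale=0.9):
    # x: feature map of shape (B, H, W, C); H, W and C must be statically known.
    x_dtype = x.dtype
    x = tf.cast(x, tf.float32)
    # FFT over the last three axes (H, W, C), then shift the zero frequency to the center
    x_freq = tf.signal.fft3d(tf.cast(x, dtype=tf.complex64))
    x_freq = tf.signal.fftshift(x_freq, axes=(1, 2, 3))
    _, H, W, C = x_freq.get_shape().as_list()
    mask = np.ones((1, H, W, C), dtype=np.complex64)

    # Scale the low-frequency block around the spatial center, for every channel
    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold, :] = scale
    x_freq = x_freq * mask

    # IFFT: undo the shift, transform back, and keep the real part in the input dtype
    x_freq = tf.signal.ifftshift(x_freq, axes=(1, 2, 3))
    x_filtered = tf.signal.ifft3d(x_freq)
    x_filtered = tf.math.real(x_filtered)
    return tf.cast(x_filtered, x_dtype)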


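def free_u(h, hs_, active=False, b1=1.2, b2=1.4, s1=0.9, s2=0.2, axis=-1):
    # h: backbone (decoder) features, hs_: skip-connection features.
    # With active=False this is a plain concatenation of h and hs_.
    if active:
        # Stage with 1280 channels: scale the first half of h by b1, filter the skip by s1
        if h.get_shape().as_list()[axis] == 1280:
            h1, h2 = tf.split(h, num_or_size_splits=2, axis=axis)
            h = tf.keras.layers.Concatenate(axis=axis)([h1 * b1, h2])
            hs_ = Fourier_filter(hs_, threshold=1, scale=s1)
        # Stage with 640 channels: scale the first half of h by b2, filter the skip by s2
        if h.get_shape().as_list()[axis] == 640:
            h1, h2 = tf.split(h, num_or_size_splits=2, axis=axis)
            h = tf.keras.layers.Concatenate(axis=axis)([h1 * b2, h2])
            hs_ = Fourier_filter(hs_, threshold=1, scale=s2)
    return tf.keras.layers.Concatenate(axis=axis)([h, hs_])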
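
To see what the Fourier filter does to a skip feature, here is a minimal NumPy sketch of the same masking idea. It is not the TensorFlow code above: the name fourier_filter_np is hypothetical, and it assumes a 2D FFT over the spatial axes only, which should give the same result because the mask is identical for every channel.

```python
import numpy as np


def fourier_filter_np(x, threshold=1, scale=0.9):
    """NumPy sketch (hypothetical helper): scale the low-frequency block of an NHWC map."""
    _, H, W, _ = x.shape
    # 2D FFT over the spatial axes, then move the zero frequency to the center.
    x_freq = np.fft.fftshift(np.fft.fft2(x, axes=(1, 2)), axes=(1, 2))

    # Same centered block as the TensorFlow version, broadcast over batch and channels.
    mask = np.ones((1, H, W, 1), dtype=np.float32)
    crow, ccol = H // 2, W // 2
    mask[:, crow - threshold:crow + threshold, ccol - threshold:ccol + threshold, :] = scale

    # Undo the shift, transform back, and keep the real part in the input dtype.
    x_freq = np.fft.ifftshift(x_freq * mask, axes=(1, 2))
    return np.fft.ifft2(x_freq, axes=(1, 2)).real.astype(x.dtype)


# A constant map contains only the zero-frequency term, so the whole map is scaled.
flat = np.ones((1, 8, 8, 4), dtype=np.float32)
flat_out = fourier_filter_np(flat, threshold=1, scale=0.5)

# A checkerboard sits at the highest spatial frequency, outside the masked block.
rows, cols = np.indices((8, 8))
checker = ((-1.0) ** (rows + cols)).astype(np.float32)[None, :, :, None]
checker_out = fourier_filter_np(checker, threshold=1, scale=0.5)
```

In this sketch the flat map is attenuated while the checkerboard passes through unchanged, which matches the role of s1 and s2: values below 1 damp the low-frequency content of the skip features and leave fine detail alone.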

Parameters

Feel free to adjust these parameters based on your models, image/video style, or tasks. The following parameters are for reference only.

SD1.x:

b1: 1.2, b2: 1.4, s1: 0.9, s2: 0.2

SD2.x:

b1: 1.1, b2: 1.2, s1: 0.9, s2: 0.2

Range for More Parameters

When trying additional parameters, consider the following ranges:

  • b1: 1 ≤ b1 ≤ 1.2
  • b2: 1.2 ≤ b2 ≤ 1.6
  • s1: s1 ≤ 1
  • s2: s2 ≤ 1
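
When sweeping parameters, it can help to keep the reference presets and the suggested ranges in one place. The sketch below is only an illustration: FREEU_PRESETS and out_of_range are hypothetical names, not part of this repository, and the bounds are copied from the lists above.

```python
# Reference presets from this README (hypothetical dict name).
FREEU_PRESETS = {
    "SD1.x": {"b1": 1.2, "b2": 1.4, "s1": 0.9, "s2": 0.2},
    "SD2.x": {"b1": 1.1, "b2": 1.2, "s1": 0.9, "s2": 0.2},
}


def out_of_range(b1, b2, s1, s2):
    """Return the names of parameters that fall outside the suggested ranges."""
    checks = {
        "b1": 1.0 <= b1 <= 1.2,
        "b2": 1.2 <= b2 <= 1.6,
        "s1": s1 <= 1.0,
        "s2": s2 <= 1.0,
    }
    return [name for name, ok in checks.items() if not ok]


# Report any preset value that leaves the suggested ranges.
for model, params in FREEU_PRESETS.items():
    print(model, out_of_range(**params))
```

Since the keys match the keyword arguments of free_u, a preset could presumably be passed along as free_u(h, hs_, active=True, **FREEU_PRESETS["SD1.x"]).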

Results from the community

If you tried FreeU and want to share your results, let me know and we can put up the link here.

Distributed under the MIT License. See LICENSE for more information.

Credits

Licenses for borrowed code can be found at the following link:

Donating

If this project is useful to you, please consider buying me a cup of coffee or sponsoring me!

Buy Me A Coffee