Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for StyleGAN2-ADA #9

Open
IvonaTau opened this issue Jun 10, 2021 · 2 comments
Open

Support for StyleGAN2-ADA #9

IvonaTau opened this issue Jun 10, 2021 · 2 comments

Comments

@IvonaTau
Copy link

Is it possible to make the code work with .pkl files trained using StyleGAN2-ADA?

The default code throws an error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-9-ddd14bed1418> in <module>
      2 for res in resolutions:
      3     filename = f"blended-{res}.jpg"
----> 4     blend_models.main(low_res_pkl, high_res_pkl, res, output_grid=filename)
      5     img = Image(filename=filename)
      6     print(f"blending at {res}x{res}")

/cluster-polyaxon/users/itau/my_projects/art/model-blending/stylegan2/blend_models.py in main(low_res_pkl, high_res_pkl, resolution, level, blend_width, output_grid, seed, output_pkl, verbose)
    105 
    106     with tf.Session() as sess, tf.device('/gpu:0'):
--> 107             low_res_G, low_res_D, low_res_Gs = misc.load_pkl(low_res_pkl)
    108             high_res_G, high_res_D, high_res_Gs = misc.load_pkl(high_res_pkl)
    109 

/cluster-polyaxon/users/itau/my_projects/art/model-blending/stylegan2/training/misc.py in load_pkl(file_or_url)
     26 def load_pkl(file_or_url):
     27     with open_file_or_url(file_or_url) as file:
---> 28         return pickle.load(file, encoding='latin1')
     29 
     30 def locate_latest_pkl(result_dir):

/cluster-polyaxon/users/itau/my_projects/art/model-blending/stylegan2/dnnlib/tflib/network.py in __setstate__(self, state)
    295 
    296         # Init TensorFlow graph.
--> 297         self._init_graph()
    298         self.reset_own_vars()
    299         tfutil.set_vars({self.find_var(name): value for name, value in state["variables"]})

/cluster-polyaxon/users/itau/my_projects/art/model-blending/stylegan2/dnnlib/tflib/network.py in _init_graph(self)
    152             with tf.control_dependencies(None):  # ignore surrounding control dependencies
    153                 self.input_templates = [tf.placeholder(tf.float32, name=name) for name in self.input_names]
--> 154                 out_expr = self._build_func(*self.input_templates, **build_kwargs)
    155 
    156         # Collect outputs.

<string> in G_synthesis(dlatents_in, dlatent_size, num_channels, resolution, fmap_base, fmap_decay, fmap_min, fmap_max, fmap_const, use_noise, randomize_noise, architecture, nonlinearity, dtype, num_fp16_res, conv_clamp, resample_kernel, fused_modconv, **_kwargs)

<string> in layer(x, layer_idx, fmaps, kernel, up)

<string> in modulated_conv2d_layer(x, y, fmaps, kernel, up, down, demodulate, resample_kernel, lrmul, fused_modconv, trainable, use_spectral_norm)

<string> in apply_bias_act(x, act, gain, lrmul, clamp, bias_var, trainable)

TypeError: fused_bias_act() got an unexpected keyword argument 'clamp'
@IvonaTau
Copy link
Author

OK, I actually made it work by using the original stylegan2-ada repo with the following modified blend_models.py script added to the main directory:

import tensorflow as tf
import sys, getopt, os

import numpy as np
import dnnlib
import dnnlib.tflib as tflib
from dnnlib.tflib import tfutil
from dnnlib.tflib.autosummary import autosummary
import math
import numpy as np

from training import dataset
# from training import misc
import pickle
import PIL.Image
import PIL.ImageFont

from pathlib import Path
import typer
from typing import Optional


def load_pkl(file_or_url):
    """Unpickle and return the object stored at a local path or URL.

    ``encoding='latin1'`` only affects how 8-bit strings pickled under
    Python 2 are decoded (see the ``pickle.load`` documentation).

    NOTE(review): unpickling can execute arbitrary code -- only load
    network pickles from sources you trust.
    """
    stream = open_file_or_url(file_or_url)
    with stream as handle:
        loaded = pickle.load(handle, encoding='latin1')
    return loaded
    
def open_file_or_url(file_or_url):
    """Return a readable binary stream for a local path or a URL.

    Local paths are opened with ``open(..., 'rb')``. URLs (as judged by
    ``dnnlib.util.is_url``) are opened through ``dnnlib.util.open_url`` with
    ``cache_dir='.stylegan2-cache'``. The caller must close the stream.
    """
    if not dnnlib.util.is_url(file_or_url):
        return open(file_or_url, 'rb')
    return dnnlib.util.open_url(file_or_url, cache_dir='.stylegan2-cache')

def adjust_dynamic_range(data, drange_in, drange_out):
    """Linearly remap ``data`` from the value range ``drange_in`` to ``drange_out``.

    Args:
        data: scalar or NumPy array to remap.
        drange_in: ``(low, high)`` pair giving the current range of ``data``
            (list, tuple or NumPy array).
        drange_out: ``(low, high)`` pair giving the desired range.

    Returns:
        ``data`` itself when both ranges hold the same values; otherwise
        ``data * scale + bias`` (scale and bias computed in float32) so that
        ``drange_in[0]`` maps to ``drange_out[0]`` and ``drange_in[1]`` maps
        to ``drange_out[1]``.

    Raises:
        ValueError: if ``drange_in`` has zero width (the mapping is undefined).
    """
    # Compare by value as tuples: ``!=`` on NumPy arrays is element-wise (and
    # ambiguous in an ``if``), and a tuple never equals a list with the same values.
    if tuple(drange_in) == tuple(drange_out):
        return data

    in_lo, in_hi = np.float32(drange_in[0]), np.float32(drange_in[1])
    out_lo, out_hi = np.float32(drange_out[0]), np.float32(drange_out[1])
    if in_hi == in_lo:
        raise ValueError(f"drange_in must have non-zero width, got {tuple(drange_in)}")

    scale = (out_hi - out_lo) / (in_hi - in_lo)
    bias = out_lo - in_lo * scale
    return data * scale + bias

def create_image_grid(images, grid_size=None):
    """Tile a batch of images into one large image, filling row by row.

    Args:
        images: NumPy array of shape ``(N, H, W)`` or ``(N, C, H, W)``.
        grid_size: optional ``(columns, rows)``. When omitted, a near-square
            layout with ``ceil(sqrt(N))`` columns (at least 1) is used.

    Returns:
        Array of shape ``(rows*H, columns*W)`` or ``(C, rows*H, columns*W)``
        with the dtype of ``images``. Cells without an image stay zero.
    """
    assert images.ndim in (3, 4)
    count = images.shape[0]
    tile_h, tile_w = images.shape[-2], images.shape[-1]

    if grid_size is None:
        cols = max(int(np.ceil(np.sqrt(count))), 1)
        rows = max((count - 1) // cols + 1, 1)
    else:
        cols, rows = tuple(grid_size)

    lead_dims = tuple(images.shape[1:-2])
    canvas = np.zeros(lead_dims + (rows * tile_h, cols * tile_w), dtype=images.dtype)
    for i, tile in enumerate(images):
        row, col = divmod(i, cols)
        top = row * tile_h
        left = col * tile_w
        canvas[..., top:top + tile_h, left:left + tile_w] = tile
    return canvas

def convert_to_pil_image(image, drange=(0, 1)):
    """Convert an ``(H, W)`` or ``(C, H, W)`` array into a ``PIL.Image``.

    Args:
        image: 2-D ``(H, W)`` or 3-D ``(C, H, W)`` NumPy array. A 3-D array
            with one channel is treated as grayscale. Any other 3-D array is
            transposed to ``(H, W, C)`` and passed to PIL as RGB
            (assumes C == 3 -- TODO confirm for other channel counts).
        drange: ``(low, high)`` value range of ``image``, mapped linearly to
            ``[0, 255]``. The default is a tuple rather than a list to avoid
            a shared mutable default argument; list arguments still work.

    Returns:
        ``PIL.Image`` in mode ``'RGB'`` (3-D after conversion) or ``'L'`` (2-D).
    """
    assert image.ndim == 2 or image.ndim == 3
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0]  # grayscale CHW => HW
        else:
            image = image.transpose(1, 2, 0)  # CHW -> HWC

    image = adjust_dynamic_range(image, drange, [0, 255])
    # Round, then clamp before casting so out-of-range values saturate
    # instead of wrapping around in uint8.
    image = np.rint(image).clip(0, 255).astype(np.uint8)
    fmt = 'RGB' if image.ndim == 3 else 'L'
    return PIL.Image.fromarray(image, fmt)

def save_image_grid(images, filename, drange=(0, 1), grid_size=None):
    """Tile ``images`` into a grid and write it to ``filename``.

    Args:
        images: batch of images shaped ``(N, H, W)`` or ``(N, C, H, W)``.
        filename: output path. PIL infers the file format from the extension
            when no explicit format is given.
        drange: ``(low, high)`` value range of ``images``. The default is a
            tuple to avoid a shared mutable default; lists are still accepted.
        grid_size: optional ``(columns, rows)``; see ``create_image_grid``.
    """
    grid = create_image_grid(images, grid_size)
    convert_to_pil_image(grid, drange).save(filename)

def save_pkl(obj, filename):
    """Pickle ``obj`` to ``filename`` (overwriting it) with the newest pickle protocol."""
    proto = pickle.HIGHEST_PROTOCOL
    with open(filename, mode='wb') as sink:
        pickle.dump(obj, sink, protocol=proto)
        

def extract_conv_names(model):
    """List the synthesis-network variables that take part in blending.

    Scans ``model.trainables`` for names under ``G_synthesis/{res}x{res}/``
    for each square resolution 4, 8, ..., 1024. Each resolution has two levels:

    * level 0: names starting with ``Conv0_up`` or ``Const``
    * level 1: names starting with ``Conv1`` or ``ToRGB``

    Returns:
        List of ``(name, "RESxRES", level, position)`` tuples, ordered by
        resolution, then level, then suffix, then the model's own variable
        order. ``position`` is a running index of the (resolution, level)
        slot. It advances even when a slot has no matching variables, so it
        depends only on the slot, not on which variables exist.

    NOTE(review): names under a resolution that match none of the suffixes
    above are not returned and therefore never blended.
    """
    names = list(model.trainables.keys())
    suffixes_per_level = (("Conv0_up", "Const"), ("Conv1", "ToRGB"))

    result = []
    slot = 0
    for power in range(9):
        side = 4 * 2 ** power
        label = f"{side}x{side}"
        prefix = f"G_synthesis/{label}/"
        for lvl, suffixes in enumerate(suffixes_per_level):
            for suffix in suffixes:
                wanted = prefix + suffix
                result.extend((n, label, lvl, slot) for n in names if n.startswith(wanted))
            slot += 1
    return result


def _blend_weight(offset, blend_width):
    """Return the model_2 fraction for a layer ``offset`` slots past the switch point.

    With a falsy ``blend_width`` (None or 0) this is a hard 0/1 switch.
    Otherwise it is the logistic ``1 / (1 + exp(-offset / blend_width))``.
    """
    if not blend_width:
        # NOTE(review): the hard switch only picks model_2 for offset > 1,
        # i.e. two (resolution, level) slots after the requested one, while
        # the logistic branch crosses 0.5 at offset == 0. Kept unchanged so
        # existing outputs stay reproducible; confirm this is intended.
        return 1 if offset > 1 else 0
    try:
        return 1 / (1 + math.exp(-offset / blend_width))
    except OverflowError:
        # exp() only overflows for a huge positive exponent, where the
        # logistic is 0 to within float precision. Previously a small
        # blend_width made this crash instead.
        return 0.0


def blend_models(model_1, model_2, resolution, level, blend_width=None, verbose=False):
    """Blend two generator networks layer by layer around a switch point.

    Each variable reported by ``extract_conv_names`` is set to
    ``model_2 * y + model_1 * (1 - y)``. Here ``y`` (0 = all model_1,
    1 = all model_2) is computed by ``_blend_weight`` from how many
    (resolution, level) slots the layer lies after the switch point. All
    other variables keep whatever ``model_1.clone()`` gave them.

    Must be called while a TensorFlow session is available to
    ``tfutil.run`` / ``tfutil.set_vars`` (``main`` opens one).

    Args:
        model_1: network supplying the layers before the switch point. It is
            cloned to form the output.
        model_2: network supplying the layers after the switch point. Its
            ``vars`` must contain every blended variable name of ``model_1``.
        resolution: integer side length (e.g. 64) of the switch-point block.
        level: 0 or 1, the level within that resolution used as the switch point.
        blend_width: None/0 for a hard switch; otherwise the width, in slots,
            of a logistic transition.
        verbose: print each variable name with its blending fraction.

    Returns:
        The blended clone of ``model_1``.

    Raises:
        ValueError: if no blendable variable exists at ``(resolution, level)``.
    """
    # TODO add small x offset for smoother blend animations
    switch_label = f"{resolution}x{resolution}"

    model_1_names = extract_conv_names(model_1)
    model_2_names = extract_conv_names(model_2)
    # zip() compares only the common prefix, so model_2 may list extra layers.
    assert all(a == b for a, b in zip(model_1_names, model_2_names)), \
        "model_1 and model_2 have mismatching synthesis layers"

    model_out = model_1.clone()

    slot_keys = [(label, lvl) for _, label, lvl, _ in model_1_names]
    try:
        mid_point_idx = slot_keys.index((switch_label, level))
    except ValueError:
        available = list(dict.fromkeys(slot_keys))
        raise ValueError(
            f"No blendable layers at resolution {switch_label}, level {level}; "
            f"available (resolution, level) pairs: {available}"
        ) from None
    mid_point_pos = model_1_names[mid_point_idx][3]

    blended = []
    for name, _label, _lvl, position in model_1_names:
        y = _blend_weight(position - mid_point_pos, blend_width)
        blended.append((name, y))
        if verbose:
            print(f"Blending {name} by {y}")

    # Evaluate all blended values in one session run, then assign them.
    tfutil.set_vars(
        tfutil.run(
            {model_out.vars[name]: (model_2.vars[name] * y + model_1.vars[name] * (1 - y))
             for name, y in blended}
        )
    )

    return model_out

def main(low_res_pkl: Path,  # pickle that supplies the low-resolution layers
         high_res_pkl: Path,  # pickle that supplies the high-resolution layers
         resolution: int,  # resolution (e.g. 64) at which to switch between models
         level: int = 0,  # level (conv block 0 or 1) within that resolution to switch at
         blend_width: Optional[float] = None,  # None = hard switch, float = logistic switch of this width
         output_grid: Optional[Path] = "blended.jpg",  # sample-grid image path (falsy = don't save)
         seed: int = 0,  # RNG seed for the grid's latent vectors
         output_pkl: Optional[Path] = None,  # output pickle path (None = don't save)
         verbose: bool = False,  # print each layer's blending fraction
         ):
    """Blend the ``Gs`` networks of two StyleGAN2 pickles.

    The blended network comes from ``blend_models``. If ``output_grid`` is
    set, a 3x3 grid of samples from random latents (seeded by ``seed``) is
    saved there. If ``output_pkl`` is set, ``(G, D, blended Gs)`` is pickled
    there, with ``G`` and ``D`` taken unchanged from ``low_res_pkl``.
    """
    tflib.init_tf()

    with tf.Session():
        with tf.device('/gpu:0'):
            low_G, low_D, low_Gs = load_pkl(low_res_pkl)
            _, _, high_Gs = load_pkl(high_res_pkl)

            blended = blend_models(low_Gs, high_Gs, resolution, level,
                                   blend_width=blend_width, verbose=verbose)

            if output_grid:
                grid_shape = (3, 3)
                latent_rng = np.random.RandomState(seed)
                latents = latent_rng.randn(np.prod(grid_shape), *blended.input_shape[1:])
                samples = blended.run(latents, None, is_validation=True, minibatch_size=1)
                save_image_grid(samples, output_grid, drange=[-1, 1], grid_size=grid_shape)

            # TODO: only Gs is blended; G and D are copied unchanged from the low-res pickle.
            if output_pkl:
                save_pkl((low_G, low_D, blended), output_pkl)
            
        
if __name__ == '__main__':
    # Build a command-line interface from main()'s signature using typer.
    typer.run(main)
    

@Norod
Copy link

Norod commented Jun 10, 2021

OK, I actually made it work by using the original stylegan2-ada repo with the following modified blend_models.py script added to the main directory:

Indeed. Here is a Colab from a few months ago which demonstrates this trick: https://github.com/Norod/my-colab-experiments/blob/master/Buntworthy_StyleGAN_ADA_blending_example.ipynb

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants