forked from NVlabs/stylegan
-
Notifications
You must be signed in to change notification settings - Fork 0
/
pretrained_example.py
executable file
·54 lines (44 loc) · 1.9 KB
/
pretrained_example.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
#
# This work is licensed under the Creative Commons Attribution-NonCommercial
# 4.0 International License. To view a copy of this license, visit
# http://creativecommons.org/licenses/by-nc/4.0/ or send a letter to
# Creative Commons, PO Box 1866, Mountain View, CA 94042, USA.
"""Minimal script for generating an image using pre-trained StyleGAN generator."""
import os
import pickle
import numpy as np
import PIL.Image
import dnnlib
import dnnlib.tflib as tflib
import config
def main():
# Initialize TensorFlow.
tflib.init_tf()
# Load pre-trained network.
# url = 'https://drive.google.com/uc?id=1MEGjdvVpUsu1jB4zrXZN7Y4kBBOzizDQ' # karras2019stylegan-ffhq-1024x1024.pkl
# with dnnlib.util.open_url(url, cache_dir=config.cache_dir) as f:
url = 'results/00001-sgan-evan-1024-2gpu/network-snapshot-025000.pkl'
with open(url, "rb") as f:
_G, _D, Gs = pickle.load(f)
# _G = Instantaneous snapshot of the generator. Mainly useful for resuming a previous training run.
# _D = Instantaneous snapshot of the discriminator. Mainly useful for resuming a previous training run.
# Gs = Long-term average of the generator. Yields higher-quality results than the instantaneous snapshot.
# Print network details.
Gs.print_layers()
rnd = np.random.RandomState()
N = Gs.input_shape[1]
latents = rnd.randn(1, N)
delta = 1e-2
direction = rnd.randint(0, N)
for i in range(1000):
latents[0, direction] += delta
# Generate image.
fmt = dict(func=tflib.convert_images_to_uint8, nchw_to_nhwc=True)
images = Gs.run(latents, None, truncation_psi=0.7, randomize_noise=True, output_transform=fmt)
# Save image.
os.makedirs(config.result_dir, exist_ok=True)
png_filename = os.path.join(config.result_dir, 'example%04d.png' % i)
PIL.Image.fromarray(images[0], 'RGB').save(png_filename)
if __name__ == "__main__":
main()