"""
Torch run example: python examples/text_to_image_sdxl.py
Compile with oneflow: python examples/text_to_image_sdxl.py --compiler oneflow
Compile with nexfort: python examples/text_to_image_sdxl.py --compiler nexfort
Test dynamic shape: Add --run_multiple_resolutions 1 and --run_rare_resolutions 1
"""
import argparse
import json
import os
import time
import torch
import oneflow as flow # usort: skip
from diffusers import StableDiffusionXLPipeline
from onediff.infer_compiler import oneflow_compile
from onediff.schedulers import EulerDiscreteScheduler
from onediffx import compile_pipe

parser = argparse.ArgumentParser()
parser.add_argument(
    "--base", type=str, default="stabilityai/stable-diffusion-xl-base-1.0"
)
parser.add_argument("--variant", type=str, default="fp16")
parser.add_argument(
    "--prompt",
    type=str,
    default="street style, detailed, raw photo, woman, face, shot on CineStill 800T",
)
parser.add_argument("--height", type=int, default=1024)
parser.add_argument("--width", type=int, default=1024)
parser.add_argument("--n_steps", type=int, default=30)
parser.add_argument("--saved_image", type=str, required=False, default="sdxl-out.png")
parser.add_argument("--seed", type=int, default=1)
# parser.add_argument(
#     "--compile_unet",
#     type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
#     default=True,
# )
# parser.add_argument(
#     "--compile_vae",
#     type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
#     default=True,
# )
parser.add_argument(
    "--compiler",
    type=str,
    default="oneflow",
    choices=["oneflow", "nexfort"],
)
parser.add_argument(
    "--compiler-config",
    type=str,
    default=None,
)
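# Example usage (a sketch mirroring the default options used further down): the
# nexfort options are parsed with json.loads, so --compiler-config takes a JSON string:
#   python examples/text_to_image_sdxl.py --compiler nexfort \
#       --compiler-config '{"mode": "max-autotune:cudagraphs", "dynamic": true}'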
parser.add_argument(
    "--run_multiple_resolutions",
    type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
    default=True,
)
parser.add_argument(
    "--run_rare_resolutions",
    type=(lambda x: str(x).lower() in ["true", "1", "yes"]),
    default=True,
)
args = parser.parse_args()
# Normal SDXL pipeline init.
OUTPUT_TYPE = "pil"
# SDXL base: StableDiffusionXLPipeline
scheduler = EulerDiscreteScheduler.from_pretrained(args.base, subfolder="scheduler")
base = StableDiffusionXLPipeline.from_pretrained(
    args.base,
    scheduler=scheduler,
    torch_dtype=torch.float16,
    variant=args.variant,
    use_safetensors=True,
)
base.to("cuda")

# # Compile unet with oneflow
# if args.compile_unet:
#     print("Compiling unet with oneflow.")
#     base.unet = oneflow_compile(base.unet)

# # Compile vae with oneflow
# if args.compile_vae:
#     print("Compiling vae with oneflow.")
#     base.vae.decoder = oneflow_compile(base.vae.decoder)
# Compile the pipe
if args.compiler == "oneflow":
    base.unet = oneflow_compile(base.unet)
elif args.compiler == "nexfort":
    if args.compiler_config is not None:
        options = json.loads(args.compiler_config)
    else:
        options = json.loads('{"mode": "max-autotune:cudagraphs", "dynamic": true}')
    base = compile_pipe(
        base, backend="nexfort", options=options, fuse_qkv_projections=True
    )
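# Note: the oneflow branch above compiles only the UNet, while the nexfort branch
# compiles the whole pipeline via compile_pipe. For reference, the default nexfort
# config decodes to the dict {"mode": "max-autotune:cudagraphs", "dynamic": True};
# fuse_qkv_projections=True is assumed to fuse the attention q/k/v projections.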
# Warmup run: compilation happens during the first call.
print("Warmup with running graphs...")
torch.manual_seed(args.seed)
image = base(
    prompt=args.prompt,
    height=args.height,
    width=args.width,
    num_inference_steps=args.n_steps,
    output_type=OUTPUT_TYPE,
).images
# Normal SDXL run
print("Normal SDXL run...")
torch.manual_seed(args.seed)
print(f"Running at resolution: {args.height}x{args.width}")
start_time = time.time()
image = base(
    prompt=args.prompt,
    height=args.height,
    width=args.width,
    num_inference_steps=args.n_steps,
    output_type=OUTPUT_TYPE,
).images
end_time = time.time()
print(f"Inference time: {end_time - start_time:.2f} seconds")
image[0].save(f"h{args.height}-w{args.width}-{args.saved_image}")
# There should be no compilation for these new input shapes.
# The nexfort backend encounters an exception when dynamically switching resolution to 960x720.
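# (These runs exercise the dynamic-shape path of the compiled graph; an unexpected
# recompilation would show up as a much slower first iteration at a given size.)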
if args.run_multiple_resolutions:
    print("Test run with multiple resolutions...")
    sizes = [960, 720, 896, 768]
    if "CI" in os.environ:
        sizes = [360]
    for h in sizes:
        for w in sizes:
            print(f"Running at resolution: {h}x{w}")
            start_time = time.time()
            image = base(
                prompt=args.prompt,
                height=h,
                width=w,
                num_inference_steps=args.n_steps,
                output_type=OUTPUT_TYPE,
            ).images
            end_time = time.time()
            print(f"Inference time: {end_time - start_time:.2f} seconds")
if args.run_rare_resolutions:
    print("Test run with another uncommon resolution...")
    h = 544
    w = 408
    print(f"Running at resolution: {h}x{w}")
    start_time = time.time()
    image = base(
        prompt=args.prompt,
        height=h,
        width=w,
        num_inference_steps=args.n_steps,
        output_type=OUTPUT_TYPE,
    ).images
    end_time = time.time()
    print(f"Inference time: {end_time - start_time:.2f} seconds")