Skip to content

Commit

Permalink
add controlnet to examples
Browse files Browse the repository at this point in the history
  • Loading branch information
chengzeyi committed Dec 11, 2023
1 parent 18c1eed commit c71e0a2
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 18 deletions.
42 changes: 36 additions & 6 deletions examples/optimize_lcm_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
CUSTOM_PIPELINE = None
SCHEDULER = 'LCMScheduler'
LORA = 'latent-consistency/lcm-lora-sdv1-5'
CONTROLNET = None
STEPS = 4
PROMPT = 'best quality, realistic, unreal engine, 4K, a beautiful girl'
SEED = None
Expand All @@ -18,7 +19,7 @@
import time
import json
import torch
from PIL import Image
from PIL import (Image, ImageDraw)
from sfast.compilers.stable_diffusion_pipeline_compiler import (
compile, CompilationConfig)

Expand All @@ -30,6 +31,7 @@ def parse_args():
parser.add_argument('--custom-pipeline', type=str, default=CUSTOM_PIPELINE)
parser.add_argument('--scheduler', type=str, default=SCHEDULER)
parser.add_argument('--lora', type=str, default=LORA)
parser.add_argument('--controlnet', type=str, default=None)
parser.add_argument('--steps', type=int, default=STEPS)
parser.add_argument('--prompt', type=str, default=PROMPT)
parser.add_argument('--seed', type=int, default=SEED)
Expand All @@ -41,6 +43,7 @@ def parse_args():
type=str,
default=EXTRA_CALL_KWARGS)
parser.add_argument('--input-image', type=str, default=None)
parser.add_argument('--control-image', type=str, default=None)
parser.add_argument('--output-image', type=str, default=None)
parser.add_argument(
'--compiler',
Expand All @@ -52,15 +55,21 @@ def parse_args():

def load_model(pipeline_cls,
model,
scheduler=None,
custom_pipeline=None,
variant=None,
lora=None):
custom_pipeline=None,
scheduler=None,
lora=None,
controlnet=None):
extra_kwargs = {}
if custom_pipeline is not None:
extra_kwargs['custom_pipeline'] = custom_pipeline
if variant is not None:
extra_kwargs['variant'] = variant
if controlnet is not None:
from diffusers import ControlNetModel
controlnet = ControlNetModel.from_pretrained(controlnet,
torch_dtype=torch.float16)
extra_kwargs['controlnet'] = controlnet
model = pipeline_cls.from_pretrained(model,
torch_dtype=torch.float16,
**extra_kwargs)
Expand Down Expand Up @@ -144,10 +153,11 @@ def main():
model = load_model(
pipeline_cls,
args.model,
scheduler=args.scheduler,
custom_pipeline=args.custom_pipeline,
variant=args.variant,
custom_pipeline=args.custom_pipeline,
scheduler=args.scheduler,
lora=args.lora,
controlnet=args.controlnet,
)
if args.compiler == 'none':
pass
Expand All @@ -167,6 +177,21 @@ def main():
input_image = input_image.resize((args.width, args.height),
Image.LANCZOS)

if args.control_image is None:
if args.controlnet is None:
control_image = None
else:
control_image = Image.new('RGB', (args.width, args.height))
draw = ImageDraw.Draw(control_image)
draw.ellipse((args.width // 4, args.height // 4,
args.width // 4 * 3, args.height // 4 * 3),
fill=(255, 255, 255))
del draw
else:
control_image = Image.open(args.control_image).convert('RGB')
control_image = control_image.resize((args.width, args.height),
Image.LANCZOS)

def get_kwarg_inputs():
kwarg_inputs = dict(
prompt=args.prompt,
Expand All @@ -181,6 +206,11 @@ def get_kwarg_inputs():
)
if input_image is not None:
kwarg_inputs['image'] = input_image
if control_image is not None:
if input_image is None:
kwarg_inputs['image'] = control_image
else:
kwarg_inputs['control_image'] = control_image
return kwarg_inputs

# NOTE: Warm it up.
Expand Down
42 changes: 36 additions & 6 deletions examples/optimize_lcm_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
CUSTOM_PIPELINE = 'latent_consistency_txt2img'
SCHEDULER = 'EulerAncestralDiscreteScheduler'
LORA = None
CONTROLNET = None
STEPS = 4
PROMPT = 'best quality, realistic, unreal engine, 4K, a beautiful girl'
SEED = None
Expand All @@ -18,7 +19,7 @@
import time
import json
import torch
from PIL import Image
from PIL import (Image, ImageDraw)
from sfast.compilers.stable_diffusion_pipeline_compiler import (
compile, CompilationConfig)

Expand All @@ -30,6 +31,7 @@ def parse_args():
parser.add_argument('--custom-pipeline', type=str, default=CUSTOM_PIPELINE)
parser.add_argument('--scheduler', type=str, default=SCHEDULER)
parser.add_argument('--lora', type=str, default=LORA)
parser.add_argument('--controlnet', type=str, default=None)
parser.add_argument('--steps', type=int, default=STEPS)
parser.add_argument('--prompt', type=str, default=PROMPT)
parser.add_argument('--seed', type=int, default=SEED)
Expand All @@ -41,6 +43,7 @@ def parse_args():
type=str,
default=EXTRA_CALL_KWARGS)
parser.add_argument('--input-image', type=str, default=None)
parser.add_argument('--control-image', type=str, default=None)
parser.add_argument('--output-image', type=str, default=None)
parser.add_argument(
'--compiler',
Expand All @@ -52,15 +55,21 @@ def parse_args():

def load_model(pipeline_cls,
model,
scheduler=None,
custom_pipeline=None,
variant=None,
lora=None):
custom_pipeline=None,
scheduler=None,
lora=None,
controlnet=None):
extra_kwargs = {}
if custom_pipeline is not None:
extra_kwargs['custom_pipeline'] = custom_pipeline
if variant is not None:
extra_kwargs['variant'] = variant
if controlnet is not None:
from diffusers import ControlNetModel
controlnet = ControlNetModel.from_pretrained(controlnet,
torch_dtype=torch.float16)
extra_kwargs['controlnet'] = controlnet
model = pipeline_cls.from_pretrained(model,
torch_dtype=torch.float16,
**extra_kwargs)
Expand Down Expand Up @@ -144,10 +153,11 @@ def main():
model = load_model(
pipeline_cls,
args.model,
scheduler=args.scheduler,
custom_pipeline=args.custom_pipeline,
variant=args.variant,
custom_pipeline=args.custom_pipeline,
scheduler=args.scheduler,
lora=args.lora,
controlnet=args.controlnet,
)
if args.compiler == 'none':
pass
Expand All @@ -167,6 +177,21 @@ def main():
input_image = input_image.resize((args.width, args.height),
Image.LANCZOS)

if args.control_image is None:
if args.controlnet is None:
control_image = None
else:
control_image = Image.new('RGB', (args.width, args.height))
draw = ImageDraw.Draw(control_image)
draw.ellipse((args.width // 4, args.height // 4,
args.width // 4 * 3, args.height // 4 * 3),
fill=(255, 255, 255))
del draw
else:
control_image = Image.open(args.control_image).convert('RGB')
control_image = control_image.resize((args.width, args.height),
Image.LANCZOS)

def get_kwarg_inputs():
kwarg_inputs = dict(
prompt=args.prompt,
Expand All @@ -181,6 +206,11 @@ def get_kwarg_inputs():
)
if input_image is not None:
kwarg_inputs['image'] = input_image
if control_image is not None:
if input_image is None:
kwarg_inputs['image'] = control_image
else:
kwarg_inputs['control_image'] = control_image
return kwarg_inputs

# NOTE: Warm it up.
Expand Down
42 changes: 36 additions & 6 deletions examples/optimize_stable_diffusion_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
CUSTOM_PIPELINE = None
SCHEDULER = 'EulerAncestralDiscreteScheduler'
LORA = None
CONTROLNET = None
STEPS = 30
PROMPT = 'best quality, realistic, unreal engine, 4K, a beautiful girl'
SEED = None
Expand All @@ -18,7 +19,7 @@
import time
import json
import torch
from PIL import Image
from PIL import (Image, ImageDraw)
from sfast.compilers.stable_diffusion_pipeline_compiler import (
compile, CompilationConfig)

Expand All @@ -30,6 +31,7 @@ def parse_args():
parser.add_argument('--custom-pipeline', type=str, default=CUSTOM_PIPELINE)
parser.add_argument('--scheduler', type=str, default=SCHEDULER)
parser.add_argument('--lora', type=str, default=LORA)
parser.add_argument('--controlnet', type=str, default=None)
parser.add_argument('--steps', type=int, default=STEPS)
parser.add_argument('--prompt', type=str, default=PROMPT)
parser.add_argument('--seed', type=int, default=SEED)
Expand All @@ -41,6 +43,7 @@ def parse_args():
type=str,
default=EXTRA_CALL_KWARGS)
parser.add_argument('--input-image', type=str, default=None)
parser.add_argument('--control-image', type=str, default=None)
parser.add_argument('--output-image', type=str, default=None)
parser.add_argument(
'--compiler',
Expand All @@ -52,15 +55,21 @@ def parse_args():

def load_model(pipeline_cls,
model,
scheduler=None,
custom_pipeline=None,
variant=None,
lora=None):
custom_pipeline=None,
scheduler=None,
lora=None,
controlnet=None):
extra_kwargs = {}
if custom_pipeline is not None:
extra_kwargs['custom_pipeline'] = custom_pipeline
if variant is not None:
extra_kwargs['variant'] = variant
if controlnet is not None:
from diffusers import ControlNetModel
controlnet = ControlNetModel.from_pretrained(controlnet,
torch_dtype=torch.float16)
extra_kwargs['controlnet'] = controlnet
model = pipeline_cls.from_pretrained(model,
torch_dtype=torch.float16,
**extra_kwargs)
Expand Down Expand Up @@ -144,10 +153,11 @@ def main():
model = load_model(
pipeline_cls,
args.model,
scheduler=args.scheduler,
custom_pipeline=args.custom_pipeline,
variant=args.variant,
custom_pipeline=args.custom_pipeline,
scheduler=args.scheduler,
lora=args.lora,
controlnet=args.controlnet,
)
if args.compiler == 'none':
pass
Expand All @@ -167,6 +177,21 @@ def main():
input_image = input_image.resize((args.width, args.height),
Image.LANCZOS)

if args.control_image is None:
if args.controlnet is None:
control_image = None
else:
control_image = Image.new('RGB', (args.width, args.height))
draw = ImageDraw.Draw(control_image)
draw.ellipse((args.width // 4, args.height // 4,
args.width // 4 * 3, args.height // 4 * 3),
fill=(255, 255, 255))
del draw
else:
control_image = Image.open(args.control_image).convert('RGB')
control_image = control_image.resize((args.width, args.height),
Image.LANCZOS)

def get_kwarg_inputs():
kwarg_inputs = dict(
prompt=args.prompt,
Expand All @@ -181,6 +206,11 @@ def get_kwarg_inputs():
)
if input_image is not None:
kwarg_inputs['image'] = input_image
if control_image is not None:
if input_image is None:
kwarg_inputs['image'] = control_image
else:
kwarg_inputs['control_image'] = control_image
return kwarg_inputs

# NOTE: Warm it up.
Expand Down

0 comments on commit c71e0a2

Please sign in to comment.