Skip to content

Commit

Permalink
More fixes for demo.
Browse files Browse the repository at this point in the history
  • Loading branch information
monorimet committed Jun 2, 2024
1 parent f9a1d35 commit 02285b3
Show file tree
Hide file tree
Showing 5 changed files with 44 additions and 42 deletions.
19 changes: 10 additions & 9 deletions apps/shark_studio/api/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def __init__(
external_weights: str = "safetensors",
progress=gr.Progress(),
):
progress(0, desc="Initializing pipeline...")
progress(None, desc="Initializing pipeline...")
self.ui_device = device
self.precision = precision
self.compiled_pipeline = False
Expand Down Expand Up @@ -181,7 +181,7 @@ def __init__(
external_weights=external_weights,
custom_vae=custom_vae,
)
progress(1, desc="Pipeline initialized!...")
progress(None, desc="Pipeline initialized!...")
gc.collect()

def prepare_pipe(
Expand Down Expand Up @@ -245,18 +245,18 @@ def prepare_pipe(
"diffusion_pytorch_model.safetensors",
)
weights[key] = save_irpa(vae_weights_path, "vae.")
progress(0, desc=f"Preparing pipeline for {self.ui_device}...")
progress(None, desc=f"Preparing pipeline for {self.ui_device}...")

vmfbs, weights = self.sd_pipe.check_prepared(
mlirs, vmfbs, weights, interactive=False
)
progress(1, desc=f"Artifacts ready!")
progress(0, desc=f"Loading pipeline on device {self.ui_device}...")
progress(None, desc=f"Artifacts ready!")
progress(None, desc=f"Loading pipeline on device {self.ui_device}...")

self.sd_pipe.load_pipeline(
vmfbs, weights, self.rt_device, self.compiled_pipeline
)
progress(1, desc="Pipeline loaded!")
progress(None, desc="Pipeline loaded! Generating images...")
return

def generate_images(
Expand All @@ -271,9 +271,9 @@ def generate_images(
resample_type,
control_mode,
hints,
progress=gr.Progress(track_tqdm=True),
progress=gr.Progress()
):
progress(0, desc="Generating images...")

img = self.sd_pipe.generate_images(
prompt,
negative_prompt,
Expand All @@ -282,7 +282,6 @@ def generate_images(
seed,
return_imgs=True,
)
progress(1, desc="Image generation complete!")
return img


Expand Down Expand Up @@ -453,6 +452,8 @@ def shark_sd_fn(
generated_imgs = []
if seed == -1:
seed = randint(0, sys.maxsize)
progress(None, desc=f"Generating...")

for current_batch in range(batch_count):
start_time = time.time()
out_imgs = global_obj.get_sd_obj().generate_images(**submit_run_kwargs)
Expand Down
11 changes: 4 additions & 7 deletions apps/shark_studio/modules/shared_cmd_opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,7 @@ def is_valid_file(arg):
"--prompt",
nargs="+",
default=[
"a photo taken of the front of a super-car drifting on a road near "
"mountains at high speeds with smoke coming off the tires, front "
"angle, front point of view, trees in the mountains of the "
"background, ((sharp focus))"
"A hi-res photo of a red street racer drifting around a curve on a mountain, high altitude, at night, tokyo in the background, 8k"
],
help="Text of which images to be generated.",
)
Expand All @@ -62,7 +59,7 @@ def is_valid_file(arg):
p.add_argument(
"--steps",
type=int,
default=50,
default=2,
help="The number of steps to do the sampling.",
)

Expand Down Expand Up @@ -100,7 +97,7 @@ def is_valid_file(arg):
p.add_argument(
"--guidance_scale",
type=float,
default=7.5,
default=0,
help="The value to be used for guidance scaling.",
)

Expand Down Expand Up @@ -346,7 +343,7 @@ def is_valid_file(arg):
p.add_argument(
"--batch_count",
type=int,
default=1,
default=4,
help="Number of batches to be generated with random seeds in " "single execution.",
)

Expand Down
52 changes: 27 additions & 25 deletions apps/shark_studio/web/ui/sd.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
"stabilityai/stable-diffusion-xl-base-1.0",
"stabilityai/sdxl-turbo",
]
sd_default_models.extend(get_checkpoints(model_type="scripts"))


def view_json_file(file_path):
Expand Down Expand Up @@ -200,7 +201,7 @@ def save_sd_cfg(config: dict, save_name: str):
filepath += ".json"
with open(filepath, mode="w") as f:
f.write(json.dumps(config))
return "..."
return save_name


def create_canvas(width, height):
Expand Down Expand Up @@ -284,23 +285,23 @@ def base_model_changed(base_model_id):
label="\U000026F0\U0000FE0F Base Model",
info="Select or enter HF model ID",
elem_id="custom_model",
value="stabilityai/stable-diffusion-2-1-base",
value="stabilityai/sdxl-turbo",
choices=sd_default_models,
allow_custom_value=True,
) # base_model_id
with gr.Row():
height = gr.Slider(
384,
512,
1024,
value=cmd_opts.height,
step=8,
value=512,
step=512,
label="\U00002195\U0000FE0F Height",
)
width = gr.Slider(
384,
512,
1024,
value=cmd_opts.width,
step=8,
value=512,
step=512,
label="\U00002194\U0000FE0F Width",
)
with gr.Accordion(
Expand Down Expand Up @@ -410,21 +411,21 @@ def base_model_changed(base_model_id):
seed = gr.Textbox(
value=cmd_opts.seed,
label="\U0001F331\U0000FE0F Seed",
info="An integer or a JSON list of integers, -1 for random",
info="An integer, -1 for random",
show_copy_button=True,
)
scheduler = gr.Dropdown(
elem_id="scheduler",
label="\U0001F4C5\U0000FE0F Scheduler",
info="\U000E0020", # forces same height as seed
value="EulerDiscrete",
value="EulerAncestralDiscrete",
choices=scheduler_model_map.keys(),
allow_custom_value=False,
)
with gr.Row():
steps = gr.Slider(
1,
100,
50,
value=cmd_opts.steps,
step=1,
label="\U0001F3C3\U0000FE0F Steps",
Expand Down Expand Up @@ -485,17 +486,17 @@ def base_model_changed(base_model_id):
with gr.Row():
canvas_width = gr.Slider(
label="Canvas Width",
minimum=256,
minimum=512,
maximum=1024,
value=512,
step=8,
step=512,
)
canvas_height = gr.Slider(
label="Canvas Height",
minimum=256,
minimum=512,
maximum=1024,
value=512,
step=8,
step=512,
)
make_canvas = gr.Button(
value="Make Canvas!",
Expand Down Expand Up @@ -616,7 +617,7 @@ def base_model_changed(base_model_id):
visible=False, # DEMO
)
compiled_pipeline = gr.Checkbox(
False,
True,
label="Faster txt2img (SDXL only)",
)
with gr.Row():
Expand All @@ -627,7 +628,7 @@ def base_model_changed(base_model_id):
queue=False,
show_progress=False,
)
stop_batch = gr.Button("Stop")
stop_batch = gr.Button("Stop", visible=False)
with gr.Tab(label="Config", id=102) as sd_tab_config:
with gr.Column(elem_classes=["sd-right-panel"]):
with gr.Row(elem_classes=["fill"]):
Expand All @@ -653,7 +654,7 @@ def base_model_changed(base_model_id):
if cmd_opts.configs_path
else get_configs_path()
),
height=75,
height=200,
)
with gr.Column(scale=1):
save_sd_config = gr.Button(
Expand All @@ -664,13 +665,13 @@ def base_model_changed(base_model_id):
size="sm",
components=sd_json,
)
with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
#with gr.Row():
sd_config_name = gr.Textbox(
value="Config Name",
info="Name of the file this config will be saved to.",
interactive=True,
show_label=False,
)
load_sd_config.change(
fn=load_sd_cfg,
inputs=[sd_json, load_sd_config],
Expand Down Expand Up @@ -758,6 +759,7 @@ def base_model_changed(base_model_id):
outputs=[
sd_json,
],
show_progress=False,
)

status_kwargs = dict(
Expand Down
2 changes: 2 additions & 0 deletions apps/shark_studio/web/utils/file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,8 @@ def get_checkpoints_path(model_type=""):
def get_checkpoints(model_type="checkpoints"):
ckpt_files = []
file_types = checkpoints_filetypes
if model_type == "scripts":
file_types = ["shark_*.py"]
if model_type == "lora":
file_types = file_types + ("*.pt", "*.bin")
for extn in file_types:
Expand Down
2 changes: 1 addition & 1 deletion setup_venv.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ else {python -m venv .\shark.venv\}
python -m pip install --upgrade pip
pip install wheel
pip install --pre -r requirements.txt
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_compiler-20240528.279-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240528.279/iree_runtime-20240528.279-cp311-cp311-win_amd64.whl
pip install --force-reinstall https://github.com/nod-ai/SRT/releases/download/candidate-20240601.282/iree_compiler-20240601.282-cp311-cp311-win_amd64.whl https://github.com/nod-ai/SRT/releases/download/candidate-20240601.282/iree_runtime-20240601.282-cp311-cp311-win_amd64.whl
pip install -e .

Write-Host "Source your venv with ./shark.venv/Scripts/activate"

0 comments on commit 02285b3

Please sign in to comment.