Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

re-work multi --styles-file #14707

Merged
merged 1 commit into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion modules/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
parser.add_argument("--gradio-inpaint-tool", type=str, help="does not do anything")
parser.add_argument("--gradio-allowed-path", action='append', help="add path to gradio's allowed_paths, make it possible to serve files from it", default=[data_path])
parser.add_argument("--opt-channelslast", action='store_true', help="change memory type for stable diffusion to channels last")
parser.add_argument("--styles-file", type=str, help="filename to use for styles", default=os.path.join(data_path, 'styles.csv'))
parser.add_argument("--styles-file", type=str, action='append', help="path or wildcard path of styles files, allow multiple entries.", default=[])
parser.add_argument("--autolaunch", action='store_true', help="open the webui URL in the system's default browser upon launch", default=False)
parser.add_argument("--theme", type=str, help="launches the UI with light or dark theme", default=None)
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
Expand Down
3 changes: 2 additions & 1 deletion modules/shared.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys

import gradio as gr
Expand All @@ -11,7 +12,7 @@

batch_cond_uncond = True # old field, unused now in favor of shared.opts.batch_cond_uncond
parallel_processing_allowed = True
styles_filename = cmd_opts.styles_file
styles_filename = cmd_opts.styles_file = cmd_opts.styles_file if len(cmd_opts.styles_file) > 0 else [os.path.join(data_path, 'styles.csv')]
config_filename = cmd_opts.ui_settings_file
hide_dirs = {"visible": not cmd_opts.hide_ui_dir_config}

Expand Down
85 changes: 43 additions & 42 deletions modules/styles.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
from pathlib import Path
import csv
import fnmatch
import os
import os.path
import typing
import shutil


class PromptStyle(typing.NamedTuple):
name: str
prompt: str
negative_prompt: str
path: str = None
prompt: str | None
negative_prompt: str | None
path: str | None = None


def merge_prompts(style_prompt: str, prompt: str) -> str:
Expand Down Expand Up @@ -79,14 +78,19 @@ def extract_original_prompts(style: PromptStyle, prompt, negative_prompt):


class StyleDatabase:
def __init__(self, path: str):
def __init__(self, paths: list[str | Path]):
self.no_style = PromptStyle("None", "", "", None)
self.styles = {}
self.path = path

folder, file = os.path.split(self.path)
filename, _, ext = file.partition('*')
self.default_path = os.path.join(folder, filename + ext)
self.paths = paths
self.all_styles_files: list[Path] = []

folder, file = os.path.split(self.paths[0])
if '*' in file or '?' in file:
# if the first path is a wildcard pattern, find the first match else use "folder/styles.csv" as the default path
self.default_path = next(Path(folder).glob(file), Path(os.path.join(folder, 'styles.csv')))
self.paths.insert(0, self.default_path)
else:
self.default_path = Path(self.paths[0])

self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

Expand All @@ -99,33 +103,31 @@ def reload(self):
"""
self.styles.clear()

path, filename = os.path.split(self.path)

if "*" in filename:
fileglob = filename.split("*")[0] + "*.csv"
filelist = []
for file in os.listdir(path):
if fnmatch.fnmatch(file, fileglob):
filelist.append(file)
# Add a visible divider to the style list
half_len = round(len(file) / 2)
divider = f"{'-' * (20 - half_len)} {file.upper()}"
divider = f"{divider} {'-' * (40 - len(divider))}"
self.styles[divider] = PromptStyle(
f"{divider}", None, None, "do_not_save"
)
# Add styles from this CSV file
self.load_from_csv(os.path.join(path, file))
if len(filelist) == 0:
print(f"No styles found in {path} matching {fileglob}")
return
elif not os.path.exists(self.path):
print(f"Style database not found: {self.path}")
return
else:
self.load_from_csv(self.path)

def load_from_csv(self, path: str):
# scans for all styles files
all_styles_files = []
for pattern in self.paths:
folder, file = os.path.split(pattern)
if '*' in file or '?' in file:
found_files = Path(folder).glob(file)
[all_styles_files.append(file) for file in found_files]
else:
# if os.path.exists(pattern):
all_styles_files.append(Path(pattern))

# Remove any duplicate entries
seen = set()
self.all_styles_files = [s for s in all_styles_files if not (s in seen or seen.add(s))]

for styles_file in self.all_styles_files:
if len(all_styles_files) > 1:
# add divider when more than styles file
# '---------------- STYLES ----------------'
divider = f' {styles_file.stem.upper()} '.center(40, '-')
self.styles[divider] = PromptStyle(f"{divider}", None, None, "do_not_save")
if styles_file.is_file():
self.load_from_csv(styles_file)

def load_from_csv(self, path: str | Path):
with open(path, "r", encoding="utf-8-sig", newline="") as file:
reader = csv.DictReader(file, skipinitialspace=True)
for row in reader:
Expand All @@ -137,19 +139,19 @@ def load_from_csv(self, path: str):
negative_prompt = row.get("negative_prompt", "")
# Add style to database
self.styles[row["name"]] = PromptStyle(
row["name"], prompt, negative_prompt, path
row["name"], prompt, negative_prompt, str(path)
)

def get_style_paths(self) -> set:
"""Returns a set of all distinct paths of files that styles are loaded from."""
# Update any styles without a path to the default path
for style in list(self.styles.values()):
if not style.path:
self.styles[style.name] = style._replace(path=self.default_path)
self.styles[style.name] = style._replace(path=str(self.default_path))

# Create a list of all distinct paths, including the default path
style_paths = set()
style_paths.add(self.default_path)
style_paths.add(str(self.default_path))
for _, style in self.styles.items():
if style.path:
style_paths.add(style.path)
Expand Down Expand Up @@ -177,7 +179,6 @@ def apply_negative_styles_to_prompt(self, prompt, styles):

def save_styles(self, path: str = None) -> None:
# The path argument is deprecated, but kept for backwards compatibility
_ = path

style_paths = self.get_style_paths()

Expand Down
9 changes: 6 additions & 3 deletions modules/ui_prompt_styles.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,12 @@ def save_style(name, prompt, negative_prompt):
if not name:
return gr.update(visible=False)

style = styles.PromptStyle(name, prompt, negative_prompt)
existing_style = shared.prompt_styles.styles.get(name)
path = existing_style.path if existing_style is not None else None

style = styles.PromptStyle(name, prompt, negative_prompt, path)
shared.prompt_styles.styles[style.name] = style
shared.prompt_styles.save_styles(shared.styles_filename)
shared.prompt_styles.save_styles()

return gr.update(visible=True)

Expand All @@ -34,7 +37,7 @@ def delete_style(name):
return

shared.prompt_styles.styles.pop(name, None)
shared.prompt_styles.save_styles(shared.styles_filename)
shared.prompt_styles.save_styles()

return '', '', ''

Expand Down