Skip to content

Commit

Permalink
* fix error in i2i detection.
Browse files Browse the repository at this point in the history
* check for lingering special character sequences.
* small refactorings.
  • Loading branch information
acorderob committed Nov 23, 2024
1 parent 3810320 commit ee7ba94
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 48 deletions.
99 changes: 53 additions & 46 deletions ppp.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import time
from collections import namedtuple
from enum import Enum
from typing import Callable, Optional
from typing import Any, Callable, Optional

import lark
import lark.parsers
Expand Down Expand Up @@ -62,8 +62,8 @@ def __init__(
self,
logger: logging.Logger,
interrupt: Optional[Callable],
env_info: dict[str, any],
options: Optional[dict[str, any]] = None,
env_info: dict[str, Any],
options: Optional[dict[str, Any]] = None,
grammar_content: Optional[str] = None,
wildcards_obj: PPPWildcards = None,
):
Expand All @@ -80,7 +80,7 @@ def __init__(
"""
self.logger = logger
self.rng = np.random.default_rng() # gets seeded on each process prompt call
self.the_interrupt = interrupt
self.interrupt_callback = interrupt
self.options = options
self.env_info = env_info
self.wildcard_obj = wildcards_obj
Expand Down Expand Up @@ -150,10 +150,10 @@ def __init__(
self.user_variables = {}

def interrupt(self):
if self.the_interrupt is not None:
self.the_interrupt()
if self.interrupt_callback is not None:
self.interrupt_callback()

def formatOutput(self, text: str) -> str:
def format_output(self, text: str) -> str:
"""
Formats the output text by encoding it using unicode_escape and decoding it using utf-8.
Expand All @@ -165,7 +165,7 @@ def formatOutput(self, text: str) -> str:
"""
return text.encode("unicode_escape").decode("utf-8")

def isComfyUI(self) -> bool:
def is_comfy_ui(self) -> bool:
"""
Checks if the current environment is ComfyUI.
Expand Down Expand Up @@ -411,8 +411,8 @@ def __processprompts(self, prompt, negative_prompt):

# Insertions in the negative prompt
if self.debug_level == DEBUG_LEVEL.full:
self.logger.debug(self.formatOutput(f"New negative additions: {p_processor.add_at}"))
self.logger.debug(self.formatOutput(f"New negative indexes: {n_processor.insertion_at}"))
self.logger.debug(self.format_output(f"New negative additions: {p_processor.add_at}"))
self.logger.debug(self.format_output(f"New negative indexes: {n_processor.insertion_at}"))
negative_prompt = self.__add_to_insertion_points(
negative_prompt, p_processor.add_at["insertion_point"], n_processor.insertion_at
)
Expand All @@ -436,9 +436,9 @@ def __processprompts(self, prompt, negative_prompt):
ppwl = ", ".join(p_processor.detectedWildcards)
npwl = ", ".join(n_processor.detectedWildcards)
if foundP:
self.logger.error(self.formatOutput(f"In the positive prompt: {ppwl}"))
self.logger.error(self.format_output(f"In the positive prompt: {ppwl}"))
if foundNP:
self.logger.error(self.formatOutput(f"In the negative prompt: {npwl}"))
self.logger.error(self.format_output(f"In the negative prompt: {npwl}"))
if self.wil_ifwildcards == self.IFWILDCARDS_CHOICES.warn:
prompt = self.WILDCARD_WARNING + prompt
elif self.wil_ifwildcards == self.IFWILDCARDS_CHOICES.stop:
Expand All @@ -448,6 +448,13 @@ def __processprompts(self, prompt, negative_prompt):
if foundNP:
negative_prompt = self.WILDCARD_STOP.format(npwl) + negative_prompt
self.interrupt()
# Check for special character sequences that should not be in the result
compound_prompt = prompt + "\n" + negative_prompt
found_sequences = re.findall(r"::|\$\$|\$\{|[{|}]", compound_prompt)
if len(found_sequences) > 0:
self.logger.warning(
f"""Found probably invalid character sequences on the result ({', '.join(map(lambda x: '"' + x + '"', set(found_sequences)))}). Something might be wrong!"""
)
return prompt, negative_prompt

def process_prompt(
Expand Down Expand Up @@ -477,14 +484,14 @@ def process_prompt(
if self.debug_level != DEBUG_LEVEL.none:
self.logger.info(f"System variables: {self.system_variables}")
self.logger.info(f"Input seed: {seed}")
self.logger.info(self.formatOutput(f"Input prompt: {prompt}"))
self.logger.info(self.formatOutput(f"Input negative_prompt: {negative_prompt}"))
self.logger.info(self.format_output(f"Input prompt: {prompt}"))
self.logger.info(self.format_output(f"Input negative_prompt: {negative_prompt}"))
t1 = time.time()
prompt, negative_prompt = self.__processprompts(prompt, negative_prompt)
t2 = time.time()
if self.debug_level != DEBUG_LEVEL.none:
self.logger.info(self.formatOutput(f"Result prompt: {prompt}"))
self.logger.info(self.formatOutput(f"Result negative_prompt: {negative_prompt}"))
self.logger.info(self.format_output(f"Result prompt: {prompt}"))
self.logger.info(self.format_output(f"Result negative_prompt: {negative_prompt}"))
self.logger.info(f"Process prompt pair time: {t2 - t1:.3f} seconds")

# Check for constructs not processed due to parsing problems
Expand Down Expand Up @@ -514,7 +521,7 @@ def parse_prompt(self, prompt_description: str, prompt: str, parser: lark.Lark,
t1 = time.time()
try:
if self.debug_level == DEBUG_LEVEL.full:
self.logger.debug(self.formatOutput(f"Parsing {prompt_description}: '{prompt}'"))
self.logger.debug(self.format_output(f"Parsing {prompt_description}: '{prompt}'"))
parsed_prompt = parser.parse(prompt)
# we store the contents so we can use them later even if the meta position is not valid anymore
for n in parsed_prompt.iter_subtrees():
Expand All @@ -526,7 +533,7 @@ def parse_prompt(self, prompt_description: str, prompt: str, parser: lark.Lark,
except lark.exceptions.UnexpectedInput:
if raise_parsing_error:
raise
self.logger.exception(self.formatOutput(f"Parsing failed on prompt!: {prompt}"))
self.logger.exception(self.format_output(f"Parsing failed on prompt!: {prompt}"))
t2 = time.time()
if self.debug_level == DEBUG_LEVEL.full:
self.logger.debug("Tree:\n" + textwrap.indent(re.sub(r"\n$", "", parsed_prompt.pretty()), " "))
Expand Down Expand Up @@ -707,7 +714,7 @@ def __debug_end(self, construct: str, start_result: str, duration: float, info=N
if output != "":
output = f" >> '{output}'"
self.__ppp.logger.debug(
self.__ppp.formatOutput(f"TreeProcessor.{construct} {info}({duration:.3f} seconds){output}")
self.__ppp.format_output(f"TreeProcessor.{construct} {info}({duration:.3f} seconds){output}")
)

def __eval_condition(self, cond_var: str, cond_comp: str, cond_value: str | list[str]) -> bool:
Expand Down Expand Up @@ -823,7 +830,7 @@ def promptcomp(self, tree: lark.Tree):
"""
Process a prompt composition construct in the tree.
"""
# if self.__ppp.isComfyUI():
# if self.__ppp.is_comfy_ui():
# self.__ppp.logger.warning("Prompt composition is not supported in ComfyUI.")
start_result = self.result
t1 = time.time()
Expand Down Expand Up @@ -852,7 +859,7 @@ def scheduled(self, tree: lark.Tree):
"""
Process a scheduling construct in the tree and add it to the accumulated shell.
"""
# if self.__ppp.isComfyUI():
# if self.__ppp.is_comfy_ui():
# self.__ppp.logger.warning("Prompt scheduling is not supported in ComfyUI.")
start_result = self.result
t1 = time.time()
Expand Down Expand Up @@ -888,7 +895,7 @@ def alternate(self, tree: lark.Tree):
"""
Process an alternation construct in the tree and add it to the accumulated shell.
"""
# if self.__ppp.isComfyUI():
# if self.__ppp.is_comfy_ui():
# self.__ppp.logger.warning("Prompt alternation is not supported in ComfyUI.")
start_result = self.result
t1 = time.time()
Expand Down Expand Up @@ -944,7 +951,7 @@ def attention(self, tree: lark.Tree):
weight = math.floor(weight * 100) / 100 # we round to 2 decimals
weight_str = f"{weight:.2f}".rstrip("0").rstrip(".")
self.__shell.append(self.AccumulatedShell("at", weight))
if weight == 0.9 and not self.__ppp.isComfyUI():
if weight == 0.9 and not self.__ppp.is_comfy_ui():
starttag = "["
self.result += starttag
self.__visit(current_tree)
Expand Down Expand Up @@ -1227,7 +1234,7 @@ def __get_choices(
repeating = False
if self.__ppp.debug_level == DEBUG_LEVEL.full:
self.__ppp.logger.debug(
self.__ppp.formatOutput(
self.__ppp.format_output(
f"Selecting {'repeating ' if repeating else ''}{num_choices} choice"
+ (f"s and separating with '{separator}'" if num_choices > 1 else "")
)
Expand All @@ -1241,8 +1248,8 @@ def __get_choices(
for i, c in enumerate(filtered_choice_values):
c["choice_index"] = i # we index them to later sort the results
weight = float(c.get("weight", 1.0))
theif = c.get("if", None)
if weight > 0 and (theif is None or self.__evaluate_if(theif)):
condition = c.get("if", None)
if weight > 0 and (condition is None or self.__evaluate_if(condition)):
available_choices.append(c)
weights.append(weight)
included_choices += 1
Expand Down Expand Up @@ -1305,12 +1312,12 @@ def __convert_choices_options(self, options: Optional[lark.Tree]) -> dict:
"""
if options is None:
return None
the_options = {}
options_dict = {}
if len(options.children) == 1:
the_options["sampler"] = options.children[0] if options.children[0] is not None else "~"
options_dict["sampler"] = options.children[0] if options.children[0] is not None else "~"
else:
the_options["sampler"] = options.children[0].children[0] if options.children[0] is not None else "~"
the_options["repeating"] = (
options_dict["sampler"] = options.children[0].children[0] if options.children[0] is not None else "~"
options_dict["repeating"] = (
options.children[1].children[0] == "r" if options.children[1] is not None else False
)
if len(options.children) == 4:
Expand All @@ -1321,16 +1328,16 @@ def __convert_choices_options(self, options: Optional[lark.Tree]) -> dict:
ifrom = 2
ito = 3
isep = 4
the_options["from"] = (
options_dict["from"] = (
int(options.children[ifrom].children[0]) if options.children[ifrom] is not None else 1
)
the_options["to"] = int(options.children[ito].children[0]) if options.children[ito] is not None else 1
the_options["separator"] = (
options_dict["to"] = int(options.children[ito].children[0]) if options.children[ito] is not None else 1
options_dict["separator"] = (
self.__visit(options.children[isep], False, True)
if options.children[isep] is not None
else self.__ppp.wil_choice_separator
)
return the_options
return options_dict

def __convert_choice(self, choice: lark.Tree) -> dict:
"""
Expand All @@ -1342,17 +1349,17 @@ def __convert_choice(self, choice: lark.Tree) -> dict:
Returns:
dict: The converted choice.
"""
the_choice = {}
choice_dict = {}
c_label_obj = choice.children[0]
the_choice["labels"] = (
choice_dict["labels"] = (
[x.value.lower() for x in c_label_obj.children[1:-1]] # should be a token
if c_label_obj is not None
else []
)
the_choice["weight"] = float(choice.children[1].children[0]) if choice.children[1] is not None else 1.0
the_choice["if"] = choice.children[2].children[0] if choice.children[2] is not None else None
the_choice["content"] = choice.children[3]
return the_choice
choice_dict["weight"] = float(choice.children[1].children[0]) if choice.children[1] is not None else 1.0
choice_dict["if"] = choice.children[2].children[0] if choice.children[2] is not None else None
choice_dict["content"] = choice.children[3]
return choice_dict

def __check_wildcard_initialization(self, wildcard: PPPWildcard):
"""
Expand Down Expand Up @@ -1414,15 +1421,15 @@ def __check_wildcard_initialization(self, wildcard: PPPWildcard):
for cv in wildcard.unprocessed_choices[n:]:
if isinstance(cv, dict):
if self.__ppp.wildcard_obj.is_dict_choice_options(cv):
theif = cv.get("if", None)
if theif is not None and isinstance(theif, str):
condition = cv.get("if", None)
if condition is not None and isinstance(condition, str):
try:
cv["if"] = self.__ppp.parse_prompt(
"condition", theif, self.__ppp.parser_condition, True
"condition", condition, self.__ppp.parser_condition, True
)
except lark.exceptions.UnexpectedInput as e:
self.__ppp.logger.warning(
f"Error parsing condition '{theif}' in wildcard '{wildcard.key}'! : {e.__class__.__name__}"
f"Error parsing condition '{condition}' in wildcard '{wildcard.key}'! : {e.__class__.__name__}"
)
cv["if"] = None
content = cv.get("content", cv.get("text", None))
Expand Down Expand Up @@ -1632,7 +1639,7 @@ def start(self, tree):
self.__already_processed.append(content)
if self.__ppp.debug_level == DEBUG_LEVEL.full:
self.__ppp.logger.debug(
self.__ppp.formatOutput(f"Adding content at position {position}: {content}")
self.__ppp.format_output(f"Adding content at position {position}: {content}")
)
if position == "e":
self.add_at["end"].append(content)
Expand All @@ -1642,6 +1649,6 @@ def start(self, tree):
else: # position == "s" or invalid
self.add_at["start"].append(content)
else:
self.__ppp.logger.warning(self.__ppp.formatOutput(f"Ignoring repeated content: {content}"))
self.__ppp.logger.warning(self.__ppp.format_output(f"Ignoring repeated content: {content}"))
t2 = time.time()
self.__debug_end("start", "", t2 - t1)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "sd-webui-prompt-postprocessor"
description = "Stable Diffusion WebUI & ComfyUI extension to post-process the prompt, including sending content from the prompt to the negative prompt and wildcards."
version = "2.8.0"
version = "2.8.1"
license = {file = "LICENSE.txt"}
dependencies = ["lark"]

Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
lark
numpy
pyyaml
3 changes: 2 additions & 1 deletion scripts/ppp_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def process(
t1 = time.time()
if getattr(opts, "prompt_attention", "") == "Compel parser":
self.ppp_logger.warning("Compel parser is not supported!")
is_i2i = getattr(p, "init_images", [None])[0] is not None
init_images = getattr(p, "init_images", [None]) or [None]
is_i2i = bool(init_images[0])
self.ppp_debug_level = DEBUG_LEVEL(getattr(opts, "ppp_gen_debug_level", DEBUG_LEVEL.none.value))
do_i2i = getattr(opts, "ppp_gen_doi2i", False)
if is_i2i and not do_i2i:
Expand Down

0 comments on commit ee7ba94

Please sign in to comment.