Skip to content

Commit

Permalink
escape/unescape for llm node inputs and prompt tool output
Browse files Browse the repository at this point in the history
  • Loading branch information
yalu4 committed Apr 22, 2024
1 parent dfe16a4 commit ff60b9b
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 24 deletions.
18 changes: 18 additions & 0 deletions src/promptflow-core/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,6 +466,11 @@ def _resolve_prompt_node(self, node: Node) -> ResolvedTool:
)
self._validate_duplicated_inputs(prompt_tpl_inputs_mapping.keys(), param_names, msg)
node.inputs = self._load_images_for_prompt_tpl(prompt_tpl_inputs_mapping, node.inputs)

flow_input_list = self._get_flow_input_list(node)
if flow_input_list:
node.inputs["flow_inputs"] = InputAssignment(value=flow_input_list, value_type=InputValueType.LITERAL)

callable = partial(render_template_jinja2, template=prompt_tpl)
return ResolvedTool(node=node, definition=None, callable=callable, init_args={})

Expand Down Expand Up @@ -526,6 +531,15 @@ def _resolve_llm_connection_with_provider(connection):
provider = connection_type_to_api_mapping[connection_type]
return connection, provider

def _get_flow_input_list(self, node: Node):
inputs = node.inputs
flow_input_list = []
if node.type == ToolType.LLM or node.type == ToolType.PROMPT or node.type == ToolType.CUSTOM_LLM:
for k, v in inputs.items():
if v.value_type == InputValueType.FLOW_INPUT:
flow_input_list.append(k)
return flow_input_list

def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
connection, provider = self._resolve_llm_connection_with_provider(self._get_llm_node_connection(node))
# Always set the provider according to the connection type
Expand All @@ -537,6 +551,10 @@ def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTo
updated_node.inputs[key] = InputAssignment(value=connection, value_type=InputValueType.LITERAL)
if convert_input_types:
updated_node = self._convert_node_literal_input_types(updated_node, tool)

flow_input_list = self._get_flow_input_list(updated_node)
if flow_input_list:
updated_node.inputs["flow_inputs"] = InputAssignment(value=flow_input_list, value_type=InputValueType.LITERAL)

prompt_tpl = self._load_source_content(node)
prompt_tpl_inputs_mapping = get_inputs_for_prompt_template(prompt_tpl)
Expand Down
59 changes: 38 additions & 21 deletions src/promptflow-tools/promptflow/tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@

GPT4V_VERSION = "vision-preview"
VALID_ROLES = ["system", "user", "assistant", "function", "tool"]
ESCAPE_DICT = {role: (lambda key: '-'.join(str(ord(c)) for c in key))(role)
for role in ["system", "user", "assistant", "function", "tool"]}


class Deployment:
Expand All @@ -56,6 +58,19 @@ def __str__(self):
return "\n".join(map(str, self))


class ExtendedStr(str):
def __init__(self, string):
super().__init__()
self.original_string = string
self.escaped_string = ""

def get_escape_str(self):
return self.escaped_string

def __str__(self):
return self.original_string


def validate_role(role: str, valid_roles: List[str] = None):
if not valid_roles:
valid_roles = VALID_ROLES
Expand Down Expand Up @@ -586,13 +601,6 @@ def render_jinja_template(prompt, trim_blocks=True, keep_trailing_newline=True,
raise JinjaTemplateError(message=error_message) from e


def build_escape_dict(kwargs: dict):
escape_dict = {}
for _, value in kwargs.items():
escape_dict = _build_escape_dict(value, escape_dict)
return escape_dict


def _build_escape_dict(val, escape_dict: dict):
"""
Build escape dictionary with roles as keys and uuids as values.
Expand All @@ -613,21 +621,23 @@ def _build_escape_dict(val, escape_dict: dict):
return escape_dict


def escape_roles(val, escape_dict: dict):
def escape_roles(val):
"""
Escape the roles in the prompt inputs to avoid the input string with pattern '# role' get parsed.
"""
if isinstance(val, ChatInputList):
return ChatInputList([escape_roles(item, escape_dict) for item in val])
return ChatInputList([escape_roles(item) for item in val])
elif isinstance(val, str):
for role, encoded_role in escape_dict.items():
val = val.replace(role, encoded_role)
pattern = r"(?i)^\s*#?\s*(" + "|".join(VALID_ROLES) + r")\s*:\s*\n"
roles = re.findall(pattern, val, flags=re.MULTILINE)
for role in roles:
val = val.replace(role, ESCAPE_DICT[role.lower()])
return val
else:
return val


def unescape_roles(val, escape_dict: dict):
def unescape_roles(val):
"""
Unescape the roles in the parsed chat messages to restore the original role names.
Expand All @@ -643,13 +653,13 @@ def unescape_roles(val, escape_dict: dict):
}]
"""
if isinstance(val, str):
for role, encoded_role in escape_dict.items():
for role, encoded_role in ESCAPE_DICT.items():
val = val.replace(encoded_role, role)
return val
elif isinstance(val, list):
for index, item in enumerate(val):
if isinstance(item, dict) and "text" in item:
for role, encoded_role in escape_dict.items():
for role, encoded_role in ESCAPE_DICT.items():
val[index]["text"] = item["text"].replace(encoded_role, role)
return val
else:
Expand All @@ -662,24 +672,31 @@ def build_messages(
image_detail: str = 'auto',
**kwargs,
):
# Use escape/unescape to avoid unintended parsing of role in user inputs.
escape_dict = build_escape_dict(kwargs)
updated_kwargs = {
key: escape_roles(value, escape_dict) for key, value in kwargs.items()
}
flow_input_list = kwargs.pop("flow_inputs", None)
updated_kwargs = kwargs
if flow_input_list:
# Use escape/unescape to avoid unintended parsing of role in user inputs.
# 1. Do escape/unescape for llm node inputs.
updated_kwargs = {
key: escape_roles(value) if key in flow_input_list else value for key, value in kwargs.items()
}
# 2. Do escape/unescape for prompt tool outputs.
updated_kwargs = {
key: value.get_escape_str() if isinstance(value, ExtendedStr) else value for key, value in updated_kwargs.items()
}

# keep_trailing_newline=True is to keep the last \n in the prompt to avoid converting "user:\t\n" to "user:".
chat_str = render_jinja_template(
prompt, trim_blocks=True, keep_trailing_newline=True, **updated_kwargs
)
messages = parse_chat(chat_str, images=images, image_detail=image_detail)

if escape_dict and isinstance(messages, list):
if flow_input_list and isinstance(messages, list):
for message in messages:
if not isinstance(message, dict):
continue
for key, val in message.items():
message[key] = unescape_roles(val, escape_dict)
message[key] = unescape_roles(val)

return messages

Expand Down
18 changes: 15 additions & 3 deletions src/promptflow-tools/promptflow/tools/template_rendering.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,21 @@
# Avoid circular dependencies: Use import 'from promptflow._internal' instead of 'from promptflow'
# since the code here is in promptflow namespace as well
from promptflow._internal import tool
from promptflow.tools.common import render_jinja_template
from promptflow.tools.common import render_jinja_template, ExtendedStr, escape_roles


@tool
def render_template_jinja2(template: str, **kwargs) -> str:
return render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **kwargs)
def render_template_jinja2(template: str, **kwargs) -> ExtendedStr:
flow_input_list = kwargs.pop("flow_inputs", None)
updated_kwargs = kwargs
if flow_input_list:
# Use escape/unescape to avoid unintended parsing of role in user inputs.
updated_kwargs = {
key: escape_roles(value) if key in flow_input_list else value for key, value in kwargs.items()
}

original_str = render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **kwargs)
escape_str = render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **updated_kwargs)
res = ExtendedStr(original_str)
res.escaped_string = escape_str
return res

0 comments on commit ff60b9b

Please sign in to comment.