Skip to content

Commit

Permalink
refine code, update test
Browse files Browse the repository at this point in the history
  • Loading branch information
yalu4 committed Apr 22, 2024
1 parent ff60b9b commit 8c3cf7a
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 159 deletions.
32 changes: 19 additions & 13 deletions src/promptflow-core/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

from promptflow._constants import MessageFormatType
from promptflow._core._errors import InvalidSource
from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR
from promptflow._core.tool import STREAMING_OPTION_PARAMETER_ATTR, INPUTS_TO_ESCAPE_PARAM_KEY
from promptflow._core.tools_manager import BuiltinsManager, ToolLoader, connection_type_to_api_mapping
from promptflow._utils.multimedia_utils import MultimediaProcessor
from promptflow._utils.tool_utils import (
Expand Down Expand Up @@ -466,10 +466,13 @@ def _resolve_prompt_node(self, node: Node) -> ResolvedTool:
)
self._validate_duplicated_inputs(prompt_tpl_inputs_mapping.keys(), param_names, msg)
node.inputs = self._load_images_for_prompt_tpl(prompt_tpl_inputs_mapping, node.inputs)

flow_input_list = self._get_flow_input_list(node)
if flow_input_list:
node.inputs["flow_inputs"] = InputAssignment(value=flow_input_list, value_type=InputValueType.LITERAL)
# Store the names of the inputs bound to flow inputs as a literal node input, so that tools can
# identify them and apply escape/unescape to keep role markers in user input from being parsed.
inputs_to_escape = self._get_inputs_to_escape(node)
if inputs_to_escape and node.inputs:
node.inputs[INPUTS_TO_ESCAPE_PARAM_KEY] = InputAssignment(
value=inputs_to_escape, value_type=InputValueType.LITERAL
)

callable = partial(render_template_jinja2, template=prompt_tpl)
return ResolvedTool(node=node, definition=None, callable=callable, init_args={})
Expand Down Expand Up @@ -531,14 +534,14 @@ def _resolve_llm_connection_with_provider(connection):
provider = connection_type_to_api_mapping[connection_type]
return connection, provider

def _get_inputs_to_escape(self, node: Node) -> list:
    """Collect the names of the node inputs that are bound directly to flow inputs.

    Only LLM, prompt and custom-LLM nodes are inspected; for every other node type an
    empty list is returned.

    :param node: The node whose inputs are inspected.
    :return: Names of the inputs whose value type is ``InputValueType.FLOW_INPUT``.
    """
    if node.type not in (ToolType.LLM, ToolType.PROMPT, ToolType.CUSTOM_LLM):
        return []
    # NOTE(review): the prompt-node call site guards on `node.inputs` being truthy, which
    # suggests it may be empty or None; treat that as "no inputs" instead of raising here.
    inputs = node.inputs or {}
    return [
        name
        for name, assignment in inputs.items()
        if assignment.value_type == InputValueType.FLOW_INPUT
    ]

def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
connection, provider = self._resolve_llm_connection_with_provider(self._get_llm_node_connection(node))
Expand All @@ -551,10 +554,13 @@ def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTo
updated_node.inputs[key] = InputAssignment(value=connection, value_type=InputValueType.LITERAL)
if convert_input_types:
updated_node = self._convert_node_literal_input_types(updated_node, tool)

flow_input_list = self._get_flow_input_list(updated_node)
if flow_input_list:
updated_node.inputs["flow_inputs"] = InputAssignment(value=flow_input_list, value_type=InputValueType.LITERAL)
# Store the names of the inputs bound to flow inputs as a literal node input, so that tools can
# identify them and apply escape/unescape to keep role markers in user input from being parsed.
inputs_to_escape = self._get_inputs_to_escape(updated_node)
if inputs_to_escape:
updated_node.inputs[INPUTS_TO_ESCAPE_PARAM_KEY] = InputAssignment(
value=inputs_to_escape, value_type=InputValueType.LITERAL
)

prompt_tpl = self._load_source_content(node)
prompt_tpl_inputs_mapping = get_inputs_for_prompt_template(prompt_tpl)
Expand Down
84 changes: 40 additions & 44 deletions src/promptflow-tools/promptflow/tools/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import time
from typing import List, Mapping
import uuid

from jinja2 import Template
from openai import APIConnectionError, APIStatusError, APITimeoutError, BadRequestError, OpenAIError, RateLimitError
Expand All @@ -32,6 +31,11 @@
WrappedOpenAIError,
)

try:
from promptflow._constants import INPUTS_TO_ESCAPE_PARAM_KEY
except ImportError:
INPUTS_TO_ESCAPE_PARAM_KEY = "inputs_to_escape"

GPT4V_VERSION = "vision-preview"
VALID_ROLES = ["system", "user", "assistant", "function", "tool"]
ESCAPE_DICT = {role: (lambda key: '-'.join(str(ord(c)) for c in key))(role)
Expand All @@ -58,13 +62,13 @@ def __str__(self):
return "\n".join(map(str, self))


class ExtendedStr(str):
class PromptResult(str):
def __init__(self, string):
super().__init__()
self.original_string = string
self.escaped_string = ""

def get_escape_str(self):
def get_escape(self) -> str:
return self.escaped_string

def __str__(self):
Expand All @@ -77,9 +81,10 @@ def validate_role(role: str, valid_roles: List[str] = None):

if role not in valid_roles:
valid_roles_str = ','.join([f'\'{role}:\\n\'' for role in valid_roles])
unescaped_invalid_role = _unescape_roles(role)
error_message = (
f"The Chat API requires a specific format for prompt definition, and the prompt should include separate "
f"lines as role delimiters: {valid_roles_str}. Current parsed role '{role}'"
f"lines as role delimiters: {valid_roles_str}. Current parsed role '{unescaped_invalid_role}'"
f" does not meet the requirement. If you intend to use the Completion API, please select the appropriate"
f" API type and deployment name. If you do intend to use the Chat API, please refer to the guideline at "
f"https://aka.ms/pfdoc/chat-prompt or view the samples in our gallery that contain 'Chat' in the name."
Expand Down Expand Up @@ -595,38 +600,20 @@ def render_jinja_template(prompt, trim_blocks=True, keep_trailing_newline=True,
return Template(prompt, trim_blocks=trim_blocks, keep_trailing_newline=keep_trailing_newline).render(**kwargs)
except Exception as e:
# For exceptions raised by jinja2 module, mark UserError
print(f"Exception occurs: {type(e).__name__}: {str(e)}", file=sys.stderr)
error_message = f"Failed to render jinja template: {type(e).__name__}: {str(e)}. " \
exception_message = {str(e)}
unescaped_exception_message = _unescape_roles(exception_message)
print(f"Exception occurs: {type(e).__name__}: {unescaped_exception_message}", file=sys.stderr)
error_message = f"Failed to render jinja template: {type(e).__name__}: {unescaped_exception_message}. " \
+ "Please modify your prompt to fix the issue."
raise JinjaTemplateError(message=error_message) from e


def _build_escape_dict(val, escape_dict: dict):
    """
    Build escape dictionary with roles as keys and uuids as values.

    NOTE(review): according to the hunk header (-595,38 +600,20) this helper sits on the
    removed side of the diff; it is kept here only because the page still shows it.

    :param val: A string, or a ``ChatInputList`` whose items are scanned recursively.
        Any other type is ignored.
    :param escape_dict: Mapping that is updated in place with newly found roles.
    :return: The same ``escape_dict`` object that was passed in.
    """
    if isinstance(val, ChatInputList):
        for item in val:
            _build_escape_dict(item, escape_dict)
    elif isinstance(val, str):
        # Matches a line that consists only of an optional '#', a valid role name and a colon.
        pattern = r"(?i)^\s*#?\s*(" + "|".join(VALID_ROLES) + r")\s*:\s*\n"
        for role in re.findall(pattern, val, flags=re.MULTILINE):
            if role not in escape_dict:
                # We cannot use a hard-coded hash str for each role, as the same role might be in
                # various case formats. For example, the 'system' role may vary in input as
                # 'system', 'System', 'SysteM', 'SYSTEM', etc. To convert the escaped roles back
                # to the original str, we need to use different uuids for each case.
                escape_dict[role] = str(uuid.uuid4())

    return escape_dict


def escape_roles(val):
def _escape_roles(val):
"""
Escape the roles in the prompt inputs to avoid the input string with pattern '# role' get parsed.
"""
if isinstance(val, ChatInputList):
return ChatInputList([escape_roles(item) for item in val])
return ChatInputList([_escape_roles(item) for item in val])
elif isinstance(val, str):
pattern = r"(?i)^\s*#?\s*(" + "|".join(VALID_ROLES) + r")\s*:\s*\n"
roles = re.findall(pattern, val, flags=re.MULTILINE)
Expand All @@ -637,7 +624,7 @@ def escape_roles(val):
return val


def unescape_roles(val):
def _unescape_roles(val):
"""
Unescape the roles in the parsed chat messages to restore the original role names.
Expand Down Expand Up @@ -666,37 +653,46 @@ def unescape_roles(val):
return val


def escape_roles_for_flow_inputs_and_prompt_output(kwargs: dict):
    """Return a new dict of tool inputs with role markers escaped where required.

    Escape/unescape is used to avoid unintended parsing of roles in user inputs.
    There are two scenarios to consider for the llm/prompt tool:

    1. Prompt injection directly from flow input: every key listed under
       ``INPUTS_TO_ESCAPE_PARAM_KEY`` has its value passed through ``_escape_roles``.
    2. Prompt injection from the previous linked prompt tool, where its output becomes an
       llm/prompt input: a ``PromptResult`` value is replaced by its escaped string,
       provided that string is non-empty.

    Side effect: the ``INPUTS_TO_ESCAPE_PARAM_KEY`` entry is popped from ``kwargs`` in place,
    so a caller that still needs the untouched mapping must pass a copy.

    :param kwargs: Tool keyword arguments, optionally containing ``INPUTS_TO_ESCAPE_PARAM_KEY``.
    :return: A new dict with the same keys as ``kwargs`` after the pop.
    """
    inputs_to_escape = kwargs.pop(INPUTS_TO_ESCAPE_PARAM_KEY, None)

    converted_kwargs = {}
    for key, value in kwargs.items():
        # Scenario 1.
        if inputs_to_escape and key in inputs_to_escape:
            value = _escape_roles(value)
        # Scenario 2. Fall back to the value itself when no escaped string was recorded.
        if isinstance(value, PromptResult):
            escaped = value.get_escape()
            if escaped:
                value = escaped
        converted_kwargs[key] = value

    return converted_kwargs


def build_messages(
    prompt: PromptTemplate,
    images: List = None,
    image_detail: str = 'auto',
    **kwargs,
):
    """Render the prompt template and parse the rendered text into chat messages.

    Inputs are first passed through ``escape_roles_for_flow_inputs_and_prompt_output`` so
    that role markers inside user-provided values are not parsed as role delimiters. If that
    step changed the inputs, every value of each parsed message dict is passed through
    ``_unescape_roles`` afterwards.

    :param prompt: The jinja prompt template.
    :param images: Forwarded to ``parse_chat``.
    :param image_detail: Forwarded to ``parse_chat``.
    :param kwargs: Template variables; may include the inputs-to-escape entry.
    :return: Whatever ``parse_chat`` returns, with list entries unescaped in place.
    """
    # Pass a copy: the helper pops the inputs-to-escape entry from the dict it receives,
    # and the untouched kwargs are needed below to detect whether anything changed.
    updated_kwargs = escape_roles_for_flow_inputs_and_prompt_output(kwargs.copy())
    # keep_trailing_newline=True is to keep the last \n in the prompt to avoid converting "user:\t\n" to "user:".
    chat_str = render_jinja_template(
        prompt, trim_blocks=True, keep_trailing_newline=True, **updated_kwargs
    )
    messages = parse_chat(chat_str, images=images, image_detail=image_detail)

    if kwargs != updated_kwargs and isinstance(messages, list):
        for message in messages:
            if not isinstance(message, dict):
                continue
            # Only existing keys are reassigned, so the dict size is stable during iteration.
            for key, val in message.items():
                message[key] = _unescape_roles(val)

    return messages

Expand Down
25 changes: 11 additions & 14 deletions src/promptflow-tools/promptflow/tools/template_rendering.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,18 @@
# Avoid circular dependencies: Use import 'from promptflow._internal' instead of 'from promptflow'
# since the code here is in promptflow namespace as well
from promptflow._internal import tool
from promptflow.tools.common import render_jinja_template, ExtendedStr, escape_roles
from promptflow.tools.common import render_jinja_template, PromptResult, escape_roles_for_flow_inputs_and_prompt_output


@tool
def render_template_jinja2(template: str, **kwargs) -> PromptResult:
    """Render a jinja2 template and return the result as a ``PromptResult``.

    The template is always rendered once with the inputs as given. When
    ``escape_roles_for_flow_inputs_and_prompt_output`` yields inputs that differ from the
    originals, the template is rendered a second time with them and that text is stored
    on the result as ``escaped_string``.

    :param template: The jinja2 template text.
    :param kwargs: Template variables; may include the inputs-to-escape entry.
    :return: The rendered prompt, with ``escaped_string`` set only when escaping applied.
    """
    rendered_template = render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **kwargs)
    prompt_result = PromptResult(rendered_template)

    # Pass a copy: the helper pops the inputs-to-escape entry from the dict it receives.
    updated_kwargs = escape_roles_for_flow_inputs_and_prompt_output(kwargs.copy())
    if kwargs != updated_kwargs:
        escaped_rendered_template = render_jinja_template(
            template, trim_blocks=True, keep_trailing_newline=True, **updated_kwargs
        )
        prompt_result.escaped_string = escaped_rendered_template
    return prompt_result
Loading

0 comments on commit 8c3cf7a

Please sign in to comment.