diff --git a/src/promptflow-core/promptflow/executor/_tool_resolver.py b/src/promptflow-core/promptflow/executor/_tool_resolver.py index 9458986cf12f..ca56af59f2e3 100644 --- a/src/promptflow-core/promptflow/executor/_tool_resolver.py +++ b/src/promptflow-core/promptflow/executor/_tool_resolver.py @@ -466,6 +466,11 @@ def _resolve_prompt_node(self, node: Node) -> ResolvedTool: ) self._validate_duplicated_inputs(prompt_tpl_inputs_mapping.keys(), param_names, msg) node.inputs = self._load_images_for_prompt_tpl(prompt_tpl_inputs_mapping, node.inputs) + + flow_input_list = self._get_flow_input_list(node) + if flow_input_list: + node.inputs["flow_inputs"] = InputAssignment(value=flow_input_list, value_type=InputValueType.LITERAL) + callable = partial(render_template_jinja2, template=prompt_tpl) return ResolvedTool(node=node, definition=None, callable=callable, init_args={}) @@ -526,6 +531,15 @@ def _resolve_llm_connection_with_provider(connection): provider = connection_type_to_api_mapping[connection_type] return connection, provider + def _get_flow_input_list(self, node: Node): + inputs = node.inputs + flow_input_list = [] + if node.type == ToolType.LLM or node.type == ToolType.PROMPT or node.type == ToolType.CUSTOM_LLM: + for k, v in inputs.items(): + if v.value_type == InputValueType.FLOW_INPUT: + flow_input_list.append(k) + return flow_input_list + def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTool: connection, provider = self._resolve_llm_connection_with_provider(self._get_llm_node_connection(node)) # Always set the provider according to the connection type @@ -537,6 +551,10 @@ def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTo updated_node.inputs[key] = InputAssignment(value=connection, value_type=InputValueType.LITERAL) if convert_input_types: updated_node = self._convert_node_literal_input_types(updated_node, tool) + + flow_input_list = self._get_flow_input_list(updated_node) + if flow_input_list: + updated_node.inputs["flow_inputs"] = InputAssignment(value=flow_input_list, value_type=InputValueType.LITERAL) prompt_tpl = self._load_source_content(node) prompt_tpl_inputs_mapping = get_inputs_for_prompt_template(prompt_tpl) diff --git a/src/promptflow-tools/promptflow/tools/common.py b/src/promptflow-tools/promptflow/tools/common.py index ef63362f3c08..3f55521d9c54 100644 --- a/src/promptflow-tools/promptflow/tools/common.py +++ b/src/promptflow-tools/promptflow/tools/common.py @@ -34,6 +34,8 @@ GPT4V_VERSION = "vision-preview" VALID_ROLES = ["system", "user", "assistant", "function", "tool"] +ESCAPE_DICT = {role: (lambda key: '-'.join(str(ord(c)) for c in key))(role) + for role in ["system", "user", "assistant", "function", "tool"]} class Deployment: @@ -56,6 +58,19 @@ def __str__(self): return "\n".join(map(str, self)) +class ExtendedStr(str): + def __init__(self, string): + super().__init__() + self.original_string = string + self.escaped_string = "" + + def get_escape_str(self): + return self.escaped_string + + def __str__(self): + return self.original_string + + def validate_role(role: str, valid_roles: List[str] = None): if not valid_roles: valid_roles = VALID_ROLES @@ -586,13 +601,6 @@ def render_jinja_template(prompt, trim_blocks=True, keep_trailing_newline=True, raise JinjaTemplateError(message=error_message) from e -def build_escape_dict(kwargs: dict): - escape_dict = {} - for _, value in kwargs.items(): - escape_dict = _build_escape_dict(value, escape_dict) - return escape_dict - - def _build_escape_dict(val, escape_dict: dict): """ Build escape dictionary with roles as keys and uuids as values. @@ -613,21 +621,23 @@ def _build_escape_dict(val, escape_dict: dict): return escape_dict -def escape_roles(val, escape_dict: dict): +def escape_roles(val): """ Escape the roles in the prompt inputs to avoid the input string with pattern '# role' get parsed. """ if isinstance(val, ChatInputList): - return ChatInputList([escape_roles(item, escape_dict) for item in val]) + return ChatInputList([escape_roles(item) for item in val]) elif isinstance(val, str): - for role, encoded_role in escape_dict.items(): - val = val.replace(role, encoded_role) + pattern = r"(?i)^\s*#?\s*(" + "|".join(VALID_ROLES) + r")\s*:\s*\n" + roles = re.findall(pattern, val, flags=re.MULTILINE) + for role in roles: + val = val.replace(role, ESCAPE_DICT[role.lower()]) return val else: return val -def unescape_roles(val, escape_dict: dict): +def unescape_roles(val): """ Unescape the roles in the parsed chat messages to restore the original role names. @@ -643,13 +653,13 @@ def unescape_roles(val, escape_dict: dict): }] """ if isinstance(val, str): - for role, encoded_role in escape_dict.items(): + for role, encoded_role in ESCAPE_DICT.items(): val = val.replace(encoded_role, role) return val elif isinstance(val, list): for index, item in enumerate(val): if isinstance(item, dict) and "text" in item: - for role, encoded_role in escape_dict.items(): + for role, encoded_role in ESCAPE_DICT.items(): val[index]["text"] = item["text"].replace(encoded_role, role) return val else: @@ -662,11 +672,18 @@ def build_messages( image_detail: str = 'auto', **kwargs, ): - # Use escape/unescape to avoid unintended parsing of role in user inputs. - escape_dict = build_escape_dict(kwargs) - updated_kwargs = { - key: escape_roles(value, escape_dict) for key, value in kwargs.items() - } + flow_input_list = kwargs.pop("flow_inputs", None) + updated_kwargs = kwargs + if flow_input_list: + # Use escape/unescape to avoid unintended parsing of role in user inputs. + # 1. Do escape/unescape for llm node inputs. + updated_kwargs = { + key: escape_roles(value) if key in flow_input_list else value for key, value in kwargs.items() + } + # 2. Do escape/unescape for prompt tool outputs. + updated_kwargs = { + key: value.get_escape_str() if isinstance(value, ExtendedStr) else value for key, value in updated_kwargs.items() + } # keep_trailing_newline=True is to keep the last \n in the prompt to avoid converting "user:\t\n" to "user:". chat_str = render_jinja_template( @@ -674,12 +691,12 @@ def build_messages( ) messages = parse_chat(chat_str, images=images, image_detail=image_detail) - if escape_dict and isinstance(messages, list): + if flow_input_list and isinstance(messages, list): for message in messages: if not isinstance(message, dict): continue for key, val in message.items(): - message[key] = unescape_roles(val, escape_dict) + message[key] = unescape_roles(val) return messages diff --git a/src/promptflow-tools/promptflow/tools/template_rendering.py b/src/promptflow-tools/promptflow/tools/template_rendering.py index d5ab6a9c584a..3a8c67a4837a 100644 --- a/src/promptflow-tools/promptflow/tools/template_rendering.py +++ b/src/promptflow-tools/promptflow/tools/template_rendering.py @@ -1,9 +1,21 @@ # Avoid circular dependencies: Use import 'from promptflow._internal' instead of 'from promptflow' # since the code here is in promptflow namespace as well from promptflow._internal import tool -from promptflow.tools.common import render_jinja_template +from promptflow.tools.common import render_jinja_template, ExtendedStr, escape_roles @tool -def render_template_jinja2(template: str, **kwargs) -> str: - return render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **kwargs) +def render_template_jinja2(template: str, **kwargs) -> ExtendedStr: + flow_input_list = kwargs.pop("flow_inputs", None) + updated_kwargs = kwargs + if flow_input_list: + # Use escape/unescape to avoid unintended parsing of role in user inputs. + updated_kwargs = { + key: escape_roles(value) if key in flow_input_list else value for key, value in kwargs.items() + } + + original_str = render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **kwargs) + escape_str = render_jinja_template(template, trim_blocks=True, keep_trailing_newline=True, **updated_kwargs) + res = ExtendedStr(original_str) + res.escaped_string = escape_str + return res