Skip to content

Commit

Permalink
fix executor comment
Browse files Browse the repository at this point in the history
  • Loading branch information
yalu4 committed Apr 22, 2024
1 parent 8c3cf7a commit d0d496c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
1 change: 1 addition & 0 deletions src/promptflow-core/promptflow/_core/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

module_logger = logging.getLogger(__name__)
STREAMING_OPTION_PARAMETER_ATTR = "_streaming_option_parameter"
INPUTS_TO_ESCAPE_PARAM_KEY = "inputs_to_escape"


# copied from promptflow.contracts.tool import ToolType
Expand Down
27 changes: 11 additions & 16 deletions src/promptflow-core/promptflow/executor/_tool_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,13 +466,7 @@ def _resolve_prompt_node(self, node: Node) -> ResolvedTool:
)
self._validate_duplicated_inputs(prompt_tpl_inputs_mapping.keys(), param_names, msg)
node.inputs = self._load_images_for_prompt_tpl(prompt_tpl_inputs_mapping, node.inputs)
# Store flow inputs list as a node input to enable tools to identify these inputs,
# and apply escape/unescape to avoid parsing of role in user inputs.
inputs_to_escape = self._get_inputs_to_escape(node)
if inputs_to_escape and node.inputs:
node.inputs[INPUTS_TO_ESCAPE_PARAM_KEY] = InputAssignment(
value=inputs_to_escape, value_type=InputValueType.LITERAL
)
self._update_inputs_to_escape(node)

callable = partial(render_template_jinja2, template=prompt_tpl)
return ResolvedTool(node=node, definition=None, callable=callable, init_args={})
Expand Down Expand Up @@ -534,14 +528,20 @@ def _resolve_llm_connection_with_provider(connection):
provider = connection_type_to_api_mapping[connection_type]
return connection, provider

def _get_inputs_to_escape(self, node: Node) -> list:
def _update_inputs_to_escape(self, node: Node):
# Store flow inputs list as a node input to enable tools to identify these inputs,
# and apply escape/unescape to avoid parsing of role in user inputs.
inputs_to_escape = []
inputs = node.inputs
if node.type == ToolType.LLM or node.type == ToolType.PROMPT or node.type == ToolType.CUSTOM_LLM:
for k, v in inputs.items():
if v.value_type == InputValueType.FLOW_INPUT:
inputs_to_escape.append(k)
return inputs_to_escape

if inputs_to_escape and node.inputs:
node.inputs[INPUTS_TO_ESCAPE_PARAM_KEY] = InputAssignment(
value=inputs_to_escape, value_type=InputValueType.LITERAL
)

def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTool:
connection, provider = self._resolve_llm_connection_with_provider(self._get_llm_node_connection(node))
Expand All @@ -554,13 +554,8 @@ def _resolve_llm_node(self, node: Node, convert_input_types=False) -> ResolvedTo
updated_node.inputs[key] = InputAssignment(value=connection, value_type=InputValueType.LITERAL)
if convert_input_types:
updated_node = self._convert_node_literal_input_types(updated_node, tool)
# Store flow inputs list as a node input to enable tools to identify these inputs,
# and apply escape/unescape to avoid parsing of role in user inputs.
inputs_to_escape = self._get_inputs_to_escape(updated_node)
if inputs_to_escape:
updated_node.inputs[INPUTS_TO_ESCAPE_PARAM_KEY] = InputAssignment(
value=inputs_to_escape, value_type=InputValueType.LITERAL
)

self._update_inputs_to_escape(updated_node)

prompt_tpl = self._load_source_content(node)
prompt_tpl_inputs_mapping = get_inputs_for_prompt_template(prompt_tpl)
Expand Down

0 comments on commit d0d496c

Please sign in to comment.