diff --git a/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py b/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py index e515620f428..9ed016c85c8 100644 --- a/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py +++ b/src/promptflow/promptflow/_sdk/_submitter/experiment_orchestrator.py @@ -130,11 +130,13 @@ def test( logger.info("Testing completed. See full logs at %s.", test_context.output_path.as_posix()) return test_context.node_results - def _test_node(self, node, test_context) -> Run: + def _test_node(self, node, test_context): if node.type == ExperimentNodeType.FLOW: return self._test_flow_node(node, test_context) elif node.type == ExperimentNodeType.COMMAND: return self._test_command_node(node, test_context) + elif node.type == ExperimentNodeType.CHAT_GROUP: + return self._test_chat_group_node(node, test_context) raise ExperimentValueError(f"Unknown experiment node {node.name!r} type {node.type!r}") def _test_flow_node(self, node, test_context): @@ -166,6 +168,14 @@ def _test_flow_node(self, node, test_context): def _test_command_node(self, *args, **kwargs): raise NotImplementedError + def _test_chat_group_node(self, node, test_context): + from promptflow._sdk.entities._chat_group._chat_group import ChatGroup + + chat_group = ChatGroup._from_node(node, test_context) + logger.debug(f"Invoking chat group node {node.name!r}.") + chat_group.invoke() + return chat_group.conversation_history + def start(self, nodes=None, from_nodes=None, attempt=None, **kwargs): """Start an execution of nodes. diff --git a/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_group.py b/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_group.py index 45b8024659a..138112199fa 100644 --- a/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_group.py +++ b/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_group.py @@ -119,6 +119,7 @@ def invoke(self): chat_round = 0 chat_token = 0 chat_start_time = time.time() + self._conversation_history = [] while True: chat_round += 1 @@ -160,6 +161,8 @@ def _get_role_input_values(self, role: ChatRole) -> Dict[str, Any]: # initializing the chat role. if value == "${parent.conversation_history}": value = self._conversation_history + elif isinstance(value, str) and value.startswith("${"): + raise ChatGroupError(f"Unresolved input value {value!r} for role {role.role!r}.") input_values[key] = value logger.debug(f"Input values for role {role.role!r}: {input_values!r}") return input_values @@ -210,3 +213,24 @@ def _check_continue_condition(self, chat_round: int, chat_token: int, chat_start def _predict_next_role_with_llm(self) -> ChatRole: """Predict next role for non-deterministic speak order.""" raise NotImplementedError(f"Speak order {self._speak_order} is not supported yet.") + + @classmethod + def _from_node(cls, node: "ChatGroupNode", context: "ExperimentTemplateTestContext"): + """Create a chat group from a chat group node.""" + logger.debug(f"Creating chat group instance from chat group node {node.name!r}...") + roles = [ChatRole(flow=role.pop("path"), **role) for role in node.roles] + chat_group = cls( + roles=roles, + max_turns=node.max_turns, + max_tokens=node.max_tokens, + max_time=node.max_time, + stop_signal=node.stop_signal, + ) + logger.debug(f"Updating role inputs for chat group {node.name!r}.") + chat_group._update_role_inputs(context) + return chat_group + + def _update_role_inputs(self, context: "ExperimentTemplateTestContext"): + """Update role inputs with context.""" + for role in self._roles: + role._update_inputs_from_data_and_inputs(data=context.test_data, inputs=context.test_inputs) diff --git a/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_role.py b/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_role.py index e259500478c..34d272a8b05 100644 --- a/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_role.py +++ b/src/promptflow/promptflow/_sdk/entities/_chat_group/_chat_role.py @@ -95,6 +95,24 @@ def _build_role_io(self, flow: Union[str, PathLike], inputs_value: Dict = None): ) return ChatRoleInputs(inputs), ChatRoleOutputs(outputs) + def _update_inputs_from_data_and_inputs(self, data: Dict, inputs: Dict): + """Update inputs from data and inputs from experiment""" + data_prefix = "${data." + inputs_prefix = "${inputs." + for key in self._inputs: + current_input = self._inputs[key] + value = current_input["value"] + if isinstance(value, str): + if value.startswith(data_prefix): + stripped_value = value.replace(data_prefix, "").replace("}", "") + data_name, col_name = stripped_value.split(".") + if data_name in data and col_name in data[data_name]: + current_input["value"] = data[data_name][col_name] + elif value.startswith(inputs_prefix): + input_name = value.replace(inputs_prefix, "").replace("}", "") + if input_name in inputs and input_name in inputs: + current_input["value"] = inputs[input_name] + def invoke(self, *args, **kwargs): """Invoke chat role""" if args: diff --git a/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py b/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py index 70a5fcc47ef..36d16861ed3 100644 --- a/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py +++ b/src/promptflow/tests/sdk_cli_test/e2etests/test_experiment.py @@ -16,7 +16,7 @@ from promptflow._sdk._errors import ExperimentValueError, RunOperationError from promptflow._sdk._load_functions import _load_experiment, load_common from promptflow._sdk._pf_client import PFClient -from promptflow._sdk._submitter.experiment_orchestrator import ExperimentOrchestrator +from promptflow._sdk._submitter.experiment_orchestrator import ExperimentOrchestrator, ExperimentTemplateTestContext from promptflow._sdk.entities._experiment import CommandNode, Experiment, ExperimentTemplate, FlowNode TEST_ROOT = Path(__file__).parent.parent.parent @@ -320,3 +320,17 @@ def test_experiment_with_chat_group(self, pf: PFClient): else: exp = pf._experiments.get(exp.name) exp = ExperimentOrchestrator(pf, exp).start() + + @pytest.mark.usefixtures("use_secrets_config_file", "recording_injection", "setup_local_connection") + def test_experiment_test_chat_group_node(self, pf: PFClient): + template_path = EXP_ROOT / "chat-group-node-exp-template" / "exp.yaml" + template = load_common(ExperimentTemplate, source=template_path) + orchestrator = ExperimentOrchestrator(pf) + test_context = ExperimentTemplateTestContext(template=template) + chat_group_node = template.nodes[0] + assert chat_group_node.name == "multi_turn_chat" + + history = orchestrator._test_node(chat_group_node, test_context) + assert len(history) == 4 + assert history[0][0] == history[2][0] == "assistant" + assert history[1][0] == history[3][0] == "user"