From 7aca008578e9aa190a89a97a95d96a7ec8f6783e Mon Sep 17 00:00:00 2001 From: Min Shi Date: Fri, 22 Mar 2024 16:33:30 +0800 Subject: [PATCH 1/5] make resume run aggregation node run_info override the record in previous line_results --- src/promptflow/promptflow/batch/_result.py | 16 ++++++++++------ .../tests/executor/e2etests/test_batch_engine.py | 9 +++++++++ 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/src/promptflow/promptflow/batch/_result.py b/src/promptflow/promptflow/batch/_result.py index e0c721e5cab..c761e935593 100644 --- a/src/promptflow/promptflow/batch/_result.py +++ b/src/promptflow/promptflow/batch/_result.py @@ -4,7 +4,6 @@ from dataclasses import dataclass from datetime import datetime -from itertools import chain from typing import Any, List, Mapping from promptflow._utils.exception_utils import ExceptionPresenter, RootErrorCode @@ -200,8 +199,13 @@ def _get_node_status(line_results: List[LineResult], aggr_result: AggregationRes def _get_node_run_infos(line_results: List[LineResult], aggr_result: AggregationResult): - line_node_run_infos = ( - node_run_info for line_result in line_results for node_run_info in line_result.node_run_infos.values() - ) - aggr_node_run_infos = (node_run_info for node_run_info in aggr_result.node_run_infos.values()) - return chain(line_node_run_infos, aggr_node_run_infos) + node_run_infos = {node_run_info.node: node_run_info for node_run_info in aggr_result.node_run_infos.values()} + for line_result in line_results: + for node, node_run_info in line_result.node_run_infos.items(): + # For resume runs, we only want to include the latest run info for aggregation node + # If the node in line_results is not in node_run_infos or not an aggregation node, add it to the result + if node not in node_run_infos: + yield node_run_info + # Otherwise, the node run info in node_run_infos will override the record in line_results + + yield from node_run_infos.values() diff --git 
a/src/promptflow/tests/executor/e2etests/test_batch_engine.py b/src/promptflow/tests/executor/e2etests/test_batch_engine.py index 21f6bffa019..f586ec097df 100644 --- a/src/promptflow/tests/executor/e2etests/test_batch_engine.py +++ b/src/promptflow/tests/executor/e2etests/test_batch_engine.py @@ -462,6 +462,11 @@ def test_batch_resume_aggregation(self, flow_folder, resume_from_run_name, dev_c for content in contents: assert content["run_info"]["root_run_id"] == resume_run_id + status_summary = {f"__pf__.nodes.{k}": v for k, v in resume_run_batch_results.node_status.items()} + assert status_summary["__pf__.nodes.grade.completed"] == 3 + assert status_summary["__pf__.nodes.calculate_accuracy.completed"] == 1 + assert status_summary["__pf__.nodes.aggregation_assert.completed"] == 1 + @pytest.mark.parametrize( "flow_folder, resume_from_run_name", [("eval_flow_with_image_resume", "eval_flow_with_image_resume_default_20240305_111258_103000")], @@ -501,3 +506,7 @@ def test_batch_resume_aggregation_with_image(self, flow_folder, resume_from_run_ contents = load_jsonl(file_path) for content in contents: assert content["run_info"]["root_run_id"] == resume_run_id + + status_summary = {f"__pf__.nodes.{k}": v for k, v in resume_run_batch_results.node_status.items()} + assert status_summary["__pf__.nodes.flip_image.completed"] == 3 + assert status_summary["__pf__.nodes.count_image.completed"] == 1 From a77d74a45674db8545aa1b3d31418a4922d6a5d5 Mon Sep 17 00:00:00 2001 From: Min Shi Date: Tue, 26 Mar 2024 20:07:59 +0800 Subject: [PATCH 2/5] recover _get_node_run_infos --- src/promptflow/promptflow/batch/_result.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/promptflow/promptflow/batch/_result.py b/src/promptflow/promptflow/batch/_result.py index c761e935593..e0c721e5cab 100644 --- a/src/promptflow/promptflow/batch/_result.py +++ b/src/promptflow/promptflow/batch/_result.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from datetime 
import datetime +from itertools import chain from typing import Any, List, Mapping from promptflow._utils.exception_utils import ExceptionPresenter, RootErrorCode @@ -199,13 +200,8 @@ def _get_node_status(line_results: List[LineResult], aggr_result: AggregationRes def _get_node_run_infos(line_results: List[LineResult], aggr_result: AggregationResult): - node_run_infos = {node_run_info.node: node_run_info for node_run_info in aggr_result.node_run_infos.values()} - for line_result in line_results: - for node, node_run_info in line_result.node_run_infos.items(): - # For resume runs, we only want to include the latest run info for aggregation node - # If the node in line_results is not in node_run_infos or not an aggregation node, add it to the result - if node not in node_run_infos: - yield node_run_info - # Otherwise, the node run info in node_run_infos will override the record in line_results - - yield from node_run_infos.values() + line_node_run_infos = ( + node_run_info for line_result in line_results for node_run_info in line_result.node_run_infos.values() + ) + aggr_node_run_infos = (node_run_info for node_run_info in aggr_result.node_run_infos.values()) + return chain(line_node_run_infos, aggr_node_run_infos) From deaeed0d40f002eb9687668e6adf9652dae720ee Mon Sep 17 00:00:00 2001 From: Min Shi Date: Tue, 26 Mar 2024 20:33:27 +0800 Subject: [PATCH 3/5] remove aggregation nodes in node run infos --- src/promptflow/promptflow/batch/_batch_engine.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/promptflow/promptflow/batch/_batch_engine.py b/src/promptflow/promptflow/batch/_batch_engine.py index c8123ad4527..558bdb375cd 100644 --- a/src/promptflow/promptflow/batch/_batch_engine.py +++ b/src/promptflow/promptflow/batch/_batch_engine.py @@ -242,6 +242,7 @@ def _copy_previous_run_result( """ try: previous_run_results = [] + aggregation_nodes = {node.name for node in self._flow.nodes if node.aggregation} for i in range(len(batch_inputs)): 
previous_run_info: FlowRunInfo = resume_from_run_storage.load_flow_run_info(i) @@ -250,8 +251,12 @@ def _copy_previous_run_result( # Thus the root_run_id needs to be the current batch run id. previous_run_info.root_run_id = run_id previous_run_info.parent_run_id = run_id - # Load previous node run info + + # Load previous node run info and remove aggregation nodes in case it is loaded into node run info previous_node_run_infos = resume_from_run_storage.load_node_run_info_for_line(i) + for node_run_info in previous_node_run_infos: + if node_run_info.node in aggregation_nodes: + previous_node_run_infos.remove(node_run_info) previous_node_run_infos_dict = {node_run.node: node_run for node_run in previous_node_run_infos} previous_node_run_outputs = { node_info.node: node_info.output for node_info in previous_node_run_infos From 5c6d7083a7e4be57680bf24ab497ebea47615a6a Mon Sep 17 00:00:00 2001 From: Min Shi Date: Wed, 27 Mar 2024 10:44:21 +0800 Subject: [PATCH 4/5] use append instead of remove --- src/promptflow/promptflow/batch/_batch_engine.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/promptflow/promptflow/batch/_batch_engine.py b/src/promptflow/promptflow/batch/_batch_engine.py index a11a9d2dc10..c84c6a9653e 100644 --- a/src/promptflow/promptflow/batch/_batch_engine.py +++ b/src/promptflow/promptflow/batch/_batch_engine.py @@ -252,9 +252,9 @@ def _copy_previous_run_result( # Load previous node run info and remove aggregation nodes in case it is loaded into node run info previous_node_run_infos = resume_from_run_storage.load_node_run_info_for_line(i) - for node_run_info in previous_node_run_infos: - if node_run_info.node in aggregation_nodes: - previous_node_run_infos.remove(node_run_info) + previous_node_run_infos = [ + run_info for run_info in previous_node_run_infos if run_info.node not in aggregation_nodes + ] previous_node_run_infos_dict = {node_run.node: node_run for node_run in previous_node_run_infos} previous_node_run_outputs 
= { node_info.node: node_info.output for node_info in previous_node_run_infos From 95bcad8e455fc257754c68a332b5722329f580cc Mon Sep 17 00:00:00 2001 From: Min Shi Date: Wed, 27 Mar 2024 14:48:32 +0800 Subject: [PATCH 5/5] add comment for excluding aggregation nodes in previous node run infos --- src/promptflow/promptflow/batch/_batch_engine.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/promptflow/promptflow/batch/_batch_engine.py b/src/promptflow/promptflow/batch/_batch_engine.py index c84c6a9653e..6c52fbc00af 100644 --- a/src/promptflow/promptflow/batch/_batch_engine.py +++ b/src/promptflow/promptflow/batch/_batch_engine.py @@ -250,8 +250,21 @@ def _copy_previous_run_result( previous_run_info.root_run_id = run_id previous_run_info.parent_run_id = run_id - # Load previous node run info and remove aggregation nodes in case it is loaded into node run info + # Load previous node run info previous_node_run_infos = resume_from_run_storage.load_node_run_info_for_line(i) + + # In storage, aggregation nodes are persisted with filenames similar to regular nodes. + # Currently we read regular node run records by filename in the node artifacts folder, + # which may also load records of aggregation nodes, which is not intended. + # E.g., aggregation-node/000000000.jsonl will be treated as the node_run_info of the first line: + # node_artifacts/ + # ├─ non-aggregation-node/ + # │ ├─ 000000000.jsonl + # │ ├─ 000000001.jsonl + # │ ├─ 000000002.jsonl + # ├─ aggregation-node/ + # │ ├─ 000000000.jsonl + # So we filter out aggregation nodes since line records should not contain any info about them. previous_node_run_infos = [ run_info for run_info in previous_node_run_infos if run_info.node not in aggregation_nodes ]