Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] Fix incorrect completed count of aggregation node for resume run #2492

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions src/promptflow/promptflow/batch/_batch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,7 @@ def _copy_previous_run_result(
"""
try:
previous_run_results = []
aggregation_nodes = {node.name for node in self._flow.nodes if node.aggregation}
for i in range(len(batch_inputs)):
previous_run_info: FlowRunInfo = resume_from_run_storage.load_flow_run_info(i)

Expand All @@ -248,8 +249,25 @@ def _copy_previous_run_result(
# Thus the root_run_id needs to be the current batch run id.
previous_run_info.root_run_id = run_id
previous_run_info.parent_run_id = run_id

# Load previous node run info
previous_node_run_infos = resume_from_run_storage.load_node_run_info_for_line(i)

# In storage, aggregation nodes are persisted with filenames similar to regular nodes.
# Currently we read regular node run records by filename in the node artifacts folder,
# so records of aggregation nodes may unintentionally be loaded as well.
# E.g., aggregation-node/000000000.jsonl will be treated as the node_run_info of the first line:
# node_artifacts/
# ├─ non-aggregation-node/
# │ ├─ 000000000.jsonl
# │ ├─ 000000001.jsonl
# │ ├─ 000000002.jsonl
# ├─ aggregation-node/
# │ ├─ 000000000.jsonl
# So we filter out aggregation nodes since line records should not contain any info about them.
previous_node_run_infos = [
run_info for run_info in previous_node_run_infos if run_info.node not in aggregation_nodes
]
previous_node_run_infos_dict = {node_run.node: node_run for node_run in previous_node_run_infos}
previous_node_run_outputs = {
node_info.node: node_info.output for node_info in previous_node_run_infos
Expand Down
9 changes: 9 additions & 0 deletions src/promptflow/tests/executor/e2etests/test_batch_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,11 @@ def test_batch_resume_aggregation(self, flow_folder, resume_from_run_name, dev_c
for content in contents:
assert content["run_info"]["root_run_id"] == resume_run_id

status_summary = {f"__pf__.nodes.{k}": v for k, v in resume_run_batch_results.node_status.items()}
assert status_summary["__pf__.nodes.grade.completed"] == 3
assert status_summary["__pf__.nodes.calculate_accuracy.completed"] == 1
assert status_summary["__pf__.nodes.aggregation_assert.completed"] == 1

@pytest.mark.parametrize(
"flow_folder, resume_from_run_name",
[("eval_flow_with_image_resume", "eval_flow_with_image_resume_default_20240305_111258_103000")],
Expand Down Expand Up @@ -501,3 +506,7 @@ def test_batch_resume_aggregation_with_image(self, flow_folder, resume_from_run_
contents = load_jsonl(file_path)
for content in contents:
assert content["run_info"]["root_run_id"] == resume_run_id

status_summary = {f"__pf__.nodes.{k}": v for k, v in resume_run_batch_results.node_status.items()}
assert status_summary["__pf__.nodes.flip_image.completed"] == 3
assert status_summary["__pf__.nodes.count_image.completed"] == 1
Loading