Skip to content

Commit

Permalink
fix: modify self only after all processes have finished, to avoid race conditions
Browse files Browse the repository at this point in the history
Signed-off-by: zjgemi <[email protected]>
  • Loading branch information
zjgemi committed Aug 5, 2024
1 parent 778c948 commit f542d3f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 14 deletions.
22 changes: 15 additions & 7 deletions src/dflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1585,14 +1585,15 @@ def handle_expr(val, scope):
time.sleep(config["debug_batch_interval"])

failed = []
results = {}
for future in concurrent.futures.as_completed(futures):
j = futures.index(future)
try:
phase, pars, arts = future.result()
except Exception:
import traceback
traceback.print_exc()
self.parallel_steps[j].phase = "Failed"
results[j] = ("Failed", pars, arts)
if not self.continue_on_failed:
self.phase = "Failed"
if config["debug_failfast"]:
Expand All @@ -1601,14 +1602,21 @@ def handle_expr(val, scope):
else:
failed.append(self.parallel_steps[j])
else:
for name, value in pars.items():
self.parallel_steps[j].outputs.parameters[
name].value = value
for name, path in arts.items():
self.parallel_steps[j].outputs.artifacts[
name].local_path = path
results[j] = (phase, pars, arts)
logging.info("Outputs of %s collected" %
self.parallel_steps[j])

# modify self only after all processes have finished, to avoid
# race conditions
for j in results:
phase, pars, arts = results[j]
self.parallel_steps[j].phase = phase
for name, value in pars.items():
self.parallel_steps[j].outputs.parameters[
name].value = value
for name, path in arts.items():
self.parallel_steps[j].outputs.artifacts[
name].local_path = path
if len(failed) > 0:
raise RuntimeError("Step %s failed" % failed)

Expand Down
20 changes: 13 additions & 7 deletions src/dflow/steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,24 +268,30 @@ def run(self, workflow_id=None, context=None):
"batch" % config["debug_batch_interval"])
time.sleep(config["debug_batch_interval"])

results = {}
for future in concurrent.futures.as_completed(futures):
j = futures.index(future)
try:
phase, pars, arts = future.result()
except Exception:
import traceback
traceback.print_exc()
step[j].phase = "Failed"
results[j] = ("Failed", pars, arts)
if not step[j].continue_on_failed:
raise RuntimeError("Step %s failed" % step[j])
else:
for name, value in pars.items():
step[j].outputs.parameters[
name].value = value
for name, path in arts.items():
step[j].outputs.artifacts[
name].local_path = path
results[j] = (phase, pars, arts)
logging.info("Outputs of %s collected" % step[j])

# modify self only after all processes have finished, to avoid
# race conditions
for j in results:
phase, pars, arts = results[j]
step[j].phase = phase
for name, value in pars.items():
step[j].outputs.parameters[name].value = value
for name, path in arts.items():
step[j].outputs.artifacts[name].local_path = path
else:
step.run(self, context)

Expand Down

0 comments on commit f542d3f

Please sign in to comment.