Skip to content

Commit

Permalink
Merge pull request #824 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: key and group_size for merge_sliced_step
  • Loading branch information
zjgemi authored Jun 5, 2024
2 parents 2a166b8 + 3fec184 commit 53d6eb7
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
11 changes: 9 additions & 2 deletions src/dflow/plugins/dispatcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def __init__(self,
self.resources_dict = {
"number_node": 1,
"cpu_per_node": 1,
"group_size": 5,
"group_size": 1,
"envs": {
"DFLOW_WORKFLOW": "{{workflow.name}}",
"DFLOW_POD": "{{pod.name}}",
Expand Down Expand Up @@ -273,6 +273,12 @@ def render(self, template):
merge = self.merge_sliced_step and hasattr(template, "slices") and\
template.slices is not None
if merge:
assert not template.slices.sub_path, "sub_path mode of slices "\
"is incompatible with merge_sliced_step"
assert template.slices.group_size is None, "group_size of slices "\
"is incompatible with merge_sliced_step"
assert template.slices.pool_size is None, "pool_size of slices "\
"is incompatible with merge_sliced_step"
sliced_output_parameters = template.slices.output_parameter.copy()
if "dflow_success_tag" in template.outputs.parameters:
sliced_output_parameters.append("dflow_success_tag")
Expand Down Expand Up @@ -435,6 +441,8 @@ def render(self, template):
"'./%s'), exist_ok=True)\n" % path
new_template.script += " new_task_dict['backward_files']"\
".append('./%s_' + str(i))\n" % path
new_template.script += " new_task_dict['backward_files']"\
".append('log')\n" # work around no Bohrium result file
new_template.script += " with open('script' + str(i), 'w')"\
" as f:\n"
new_template.script += " f.write(new_script)\n"
Expand All @@ -444,7 +452,6 @@ def render(self, template):
"'script' + str(i)\n"
new_template.script += " tasks.append(Task.load_from_dict("\
"new_task_dict))\n"
new_template.script += "resources.group_size = 1\n"
new_template.script += "submission = Submission(work_base='.', "\
"machine=machine, resources=resources, task_list=tasks)\n"
else:
Expand Down
5 changes: 4 additions & 1 deletion src/dflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -1294,6 +1294,7 @@ def render_by_executor(self, context=None):
InputParameter(value=None)
self.inputs.parameters["dflow_sequence_format"] = \
InputParameter(value="")
key = "0"
if self.with_param is not None:
self.inputs.parameters["dflow_with_param"].value = \
self.with_param
Expand All @@ -1313,9 +1314,11 @@ def render_by_executor(self, context=None):
self.inputs.parameters["dflow_sequence_format"].value = \
format
self.with_sequence = None
if format is not None:
key = format % 0
if self.key is not None:
self.inputs.parameters["dflow_key"] = InputParameter(
value=str(self.key).replace("{{item}}", "merged"))
value=str(self.key).replace("{{item}}", key))
elif context is not None:
self.template = context.render(self.template)

Expand Down

0 comments on commit 53d6eb7

Please sign in to comment.