Skip to content

Commit

Permalink
Merge pull request #779 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
Zjgemi
  • Loading branch information
zjgemi authored Mar 18, 2024
2 parents 4afb216 + fbf5725 commit 89890b1
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions src/dflow/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,17 +371,22 @@ def wait(self, interval=1):
while self.query_status() in ["Pending", "Running"]:
time.sleep(interval)

def handle_reused_step(self, step):
def handle_reused_step(self, step, global_parameters, global_artifacts):
outputs = {}
if hasattr(step, "outputs"):
if hasattr(step.outputs, "exitCode"):
outputs["exitCode"] = step.outputs.exitCode
if hasattr(step.outputs, "parameters"):
outputs["parameters"] = []
for name, par in step.outputs.parameters.items():
if not hasattr(step.outputs.parameters[name],
"save_as_artifact"):
if not hasattr(par, "save_as_artifact"):
outputs["parameters"].append(par.recover())
if hasattr(par, "globalName") and name != \
"dflow_global":
global_par = par.recover()
global_par["name"] = par.globalName
global_par.pop("globalName", None)
global_parameters[par.globalName] = global_par
if hasattr(step.outputs, "artifacts"):
for name, art in step.outputs.artifacts.items():
group_key = step.get("inputs", {}).get(
Expand All @@ -397,6 +402,11 @@ def handle_reused_step(self, step):
else:
self.handle_reused_artifact_with_copy(
step, name, art)
if hasattr(art, "globalName"):
global_art = art.recover()
global_art["name"] = art.globalName
global_art.pop("globalName", None)
global_artifacts[art.globalName] = global_art
outputs["artifacts"] = [
art.recover() for art in step.outputs.artifacts.values()]
self.memoize_map["%s-%s" % (self.id, step.key)] = {
Expand Down Expand Up @@ -470,7 +480,8 @@ def convert_to_argo(self, reuse_step=None):
assert isinstance(self.context, (Context, Executor))
self = self.context.render(self)

status = None
global_parameters = {}
global_artifacts = {}
if reuse_step is not None:
self.reused_keys = [step.key for step in reuse_step
if step.key is not None]
Expand All @@ -484,7 +495,8 @@ def convert_to_argo(self, reuse_step=None):
if step.key is None:
continue
key2id[step.key] = step.id
self.handle_reused_step(step)
self.handle_reused_step(step, global_parameters,
global_artifacts)

for key, step in self.memoize_map.items():
data = {key: json.dumps(step)}
Expand All @@ -506,8 +518,8 @@ def convert_to_argo(self, reuse_step=None):
self.handle_template(self.entrypoint, memoize_prefix=self.id,
memoize_configmap="dflow")
if config["save_keys_in_global_outputs"]:
status = {"outputs": {"parameters": [
{"name": key, "value": id} for key, id in key2id.items()]}}
for key, id in key2id.items():
global_parameters[key] = {"name": key, "value": id}
else:
self.handle_template(self.entrypoint)

Expand Down Expand Up @@ -573,7 +585,8 @@ def convert_to_argo(self, reuse_step=None):
artifact_repository_ref=None if self.artifact_repo_key is None
else V1alpha1ArtifactRepositoryRef(key=self.artifact_repo_key)
),
status=status)
status={"outputs": {"parameters": list(global_parameters.values()),
"artifacts": list(global_artifacts.values())}})

def deduplicate_templates(self):
logger.debug("before deduplication: %s" % len(self.argo_templates))
Expand Down

0 comments on commit 89890b1

Please sign in to comment.