Skip to content

Commit

Permalink
Merge pull request #772 from deepmodeling/zjgemi
Browse files Browse the repository at this point in the history
fix: add create_dir to slices
  • Loading branch information
zjgemi authored Mar 4, 2024
2 parents 632ed99 + 74c6409 commit 81d01e8
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 21 deletions.
19 changes: 17 additions & 2 deletions src/dflow/python/python_op_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class Slices:
to handle each slice, 1 for serial, -1 for infinity (i.e. equals to
the number of slices)
register_first_only: only register first slice when lineage used
create_dir: create a separate dir for each slice for saving output
artifacts
"""

def __init__(
Expand All @@ -60,6 +62,7 @@ def __init__(
random_seed: int = 0,
pool_size: Optional[int] = None,
register_first_only: bool = False,
create_dir: bool = False,
) -> None:
self.input_parameter = input_parameter if input_parameter is not \
None else []
Expand All @@ -81,6 +84,7 @@ def __init__(
self.random_seed = random_seed
self.pool_size = pool_size
self.register_first_only = register_first_only
self.create_dir = create_dir

def evalable_repr(self, imports):
kwargs = {}
Expand Down Expand Up @@ -209,6 +213,7 @@ def __init__(self,
pre_script: str = "",
post_script: str = "",
success_tag: bool = False,
output_slice_dir: Dict[str, str] = None,
) -> None:
self.n_parts = {}
self.keys_of_parts = {}
Expand Down Expand Up @@ -371,6 +376,8 @@ def __init__(self,
else output_artifact_slices
self.output_parameter_slices = {} if output_parameter_slices is None \
else output_parameter_slices
self.output_slice_dir = {} if output_slice_dir is None\
else output_slice_dir
self.set_slices(slices)
self.download_method = "download"

Expand Down Expand Up @@ -405,6 +412,11 @@ def add_slices(self, slices: Slices, layer=0):
for name in slices.output_artifact:
self.output_artifact_slices[name] = slices.slices
self.outputs.artifacts[name].archive = None # no archive
if slices.create_dir:
self.output_slice_dir[name] = \
"{{inputs.parameters.dflow_slice_dir}}"
self.inputs.parameters["dflow_slice_dir"] = InputParameter(
value="")
if slices.output_parameter:
for name in slices.output_parameter:
self.output_parameter_slices[name] = slices.slices
Expand Down Expand Up @@ -620,9 +632,12 @@ def render_script(self):
for name, sign in output_sign.items():
if isinstance(sign, Artifact):
slices = self.get_slices(output_artifact_slices, name)
slice_dir = None
if name in self.output_slice_dir:
slice_dir = "'%s'" % self.output_slice_dir[name]
script += " handle_output_artifact('%s', output['%s'], "\
"output_sign['%s'], %s, r'%s')\n" % (name, name, name,
slices, self.tmp_root)
"output_sign['%s'], %s, r'%s', %s)\n" % (
name, name, name, slices, self.tmp_root, slice_dir)
else:
slices = self.get_slices(output_parameter_slices, name)
script += " handle_output_parameter('%s', output['%s'], "\
Expand Down
39 changes: 22 additions & 17 deletions src/dflow/python/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def handle_input_parameter(name, value, sign, slices=None, data_root="/tmp"):
return obj


def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp"):
def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp",
slice_dir=None):
path_list = []
if sign.type in [str, Path]:
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
Expand All @@ -172,7 +173,7 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp"):
slices = 0
if value and os.path.exists(str(value)):
path_list.append({"dflow_list_item": copy_results(
value, name, data_root), "order": slices})
value, name, data_root, slice_dir), "order": slices})
else:
path_list.append({"dflow_list_item": None, "order": slices})
elif sign.type in [List[str], List[Path], Set[str], Set[Path]]:
Expand All @@ -181,32 +182,32 @@ def handle_output_artifact(name, value, sign, slices=None, data_root="/tmp"):
if isinstance(slices, int):
for path in value:
path_list.append(copy_results_and_return_path_item(
path, name, slices, data_root))
path, name, slices, data_root, slice_dir))
else:
assert len(slices) == len(value)
for path, s in zip(value, slices):
if isinstance(path, list):
for p in path:
path_list.append(
copy_results_and_return_path_item(p, name, s,
data_root))
copy_results_and_return_path_item(
p, name, s, data_root, slice_dir))
else:
path_list.append(copy_results_and_return_path_item(
path, name, s, data_root))
path, name, s, data_root, slice_dir))
else:
for s, path in enumerate(value):
path_list.append(copy_results_and_return_path_item(
path, name, s, data_root))
path, name, s, data_root, slice_dir))
elif sign.type in [Dict[str, str], Dict[str, Path]]:
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
for s, path in value.items():
path_list.append(copy_results_and_return_path_item(
path, name, s, data_root))
path, name, s, data_root, slice_dir))
elif sign.type in [NestedDict[str], NestedDict[Path]]:
os.makedirs(data_root + '/outputs/artifacts/' + name, exist_ok=True)
for s, path in flatten(value).items():
path_list.append(copy_results_and_return_path_item(
path, name, s, data_root))
path, name, s, data_root, slice_dir))

os.makedirs(data_root + "/outputs/artifacts/%s/%s" % (name, config[
"catalog_dir_name"]), exist_ok=True)
Expand Down Expand Up @@ -245,15 +246,16 @@ def handle_output_parameter(name, value, sign, slices=None, data_root="/tmp"):
f.write(jsonpickle.dumps(value))


def copy_results_and_return_path_item(path, name, order, data_root="/tmp"):
def copy_results_and_return_path_item(path, name, order, data_root="/tmp",
slice_dir=None):
if path and os.path.exists(str(path)):
return {"dflow_list_item": copy_results(path, name, data_root),
"order": order}
return {"dflow_list_item": copy_results(
path, name, data_root, slice_dir), "order": order}
else:
return {"dflow_list_item": None, "order": order}


def copy_results(source, name, data_root="/tmp"):
def copy_results(source, name, data_root="/tmp", slice_dir=None):
source = str(source)
# if refer to input artifact
if source.find(data_root + "/inputs/artifacts/") == 0:
Expand All @@ -263,6 +265,8 @@ def copy_results(source, name, data_root="/tmp"):
rel_path = randstr()
else:
rel_path = source[i+1:]
if slice_dir:
rel_path = "%s/%s" % (slice_dir, rel_path)
target = data_root + "/outputs/artifacts/%s/%s" % (name, rel_path)
copy_file(source, target, shutil.copy)
if rel_path[:1] == "/":
Expand All @@ -274,11 +278,12 @@ def copy_results(source, name, data_root="/tmp"):
cwd = cwd + "/"
if source.startswith(cwd):
source = source[len(cwd):]
target = data_root + "/outputs/artifacts/%s/%s" % (name, source)
rel_path = source[1:] if source[:1] == "/" else source
if slice_dir:
rel_path = "%s/%s" % (slice_dir, rel_path)
target = data_root + "/outputs/artifacts/%s/%s" % (name, rel_path)
copy_file(source, target)
if source[:1] == "/":
source = source[1:]
return source
return rel_path


def handle_empty_dir(path):
Expand Down
24 changes: 22 additions & 2 deletions src/dflow/step.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,11 @@ def __init__(
self.inputs.parameters["dflow_slice"] = InputParameter(
value=slices.slices)

if getattr(getattr(self.template, "slices", None),
"create_dir", False):
self.inputs.parameters["dflow_slice_dir"] = InputParameter(
value="{{item}}")

sum_var = None
if isinstance(self.with_param, ArgoRange) and \
isinstance(self.with_param.end, ArgoSum):
Expand Down Expand Up @@ -540,7 +545,7 @@ def merge_step_output_artifact(art, parent, item_vars):
"dflow_artifact_key}}")
merge_output_artifact(
step.prepare_step.template.outputs.artifacts[
art.name])
art.name], None, [])
new_item_vars = []
for k, v in step.inputs.parameters.items():
for var in item_vars:
Expand Down Expand Up @@ -2349,6 +2354,9 @@ def add_slices(templ: OPTemplate, slices: Slices, layer=0):
else:
steps.append(s)

if slices.create_dir:
templ.inputs.parameters["dflow_slice_dir"] = InputParameter(value="")

for name in slices.input_parameter:
for step in steps:
for par in list(step.inputs.parameters.values()):
Expand Down Expand Up @@ -2430,6 +2438,17 @@ def stack_output_parameter(par):
def stack_output_artifact(art):
if isinstance(art, OutputArtifact):
step = art.step
if slices.create_dir:
slice_dir = "{{inputs.parameters.dflow_slice_dir}}"
value = getattr(step.inputs.parameters.get(
"dflow_slice_dir", None), "value", "")
if value:
if value.startswith(slice_dir):
slice_dir = value
else:
slice_dir += "/" + value
step.inputs.parameters["dflow_slice_dir"] = InputParameter(
value=slice_dir)
if step.template is templ:
step.inputs.parameters[slice_par] = InputParameter(
value="({{inputs.parameters.%s}} if is_outputs else None)"
Expand All @@ -2440,7 +2459,8 @@ def stack_output_artifact(art):
"{{inputs.parameters.%s}}" % slice_par_1,
output_artifact=[art.name],
sub_path=slices.sub_path,
pool_size=slices.pool_size), layer=layer+1)
pool_size=slices.pool_size, create_dir=slices.create_dir),
layer=layer+1)
step.inputs.parameters[slice_par_1] = InputParameter(
value="{{inputs.parameters.%s}}" % slice_par)

Expand Down

0 comments on commit 81d01e8

Please sign in to comment.