diff --git a/samples/basic/recursion.py b/samples/basic/recursion.py index b9057a97d5f..5bd6c652049 100644 --- a/samples/basic/recursion.py +++ b/samples/basic/recursion.py @@ -49,10 +49,6 @@ def flip_component(flip_result): # as the input whereas the flip_result in the current graph component # comes from the flipA.output in the flipcoin function. flip_component(flipA.output) - # Return a dictionary of string to arguments - # such that the downstream components that depend - # on this graph component can access the output. - return {'flip_result': flipA.output} @dsl.pipeline( name='pipeline flip coin', @@ -63,7 +59,7 @@ def flipcoin(): flip_loop = flip_component(flipA.output) # flip_loop is a graph_component with the outputs field # filled with the returned dictionary. - PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result']) + PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop) if __name__ == '__main__': import kfp.compiler as compiler diff --git a/sdk/python/kfp/compiler/compiler.py b/sdk/python/kfp/compiler/compiler.py index ff91797cdca..98616bd538a 100644 --- a/sdk/python/kfp/compiler/compiler.py +++ b/sdk/python/kfp/compiler/compiler.py @@ -84,33 +84,104 @@ def _get_op_groups_helper(current_groups, ops_to_groups): _get_op_groups_helper(current_groups, ops_to_groups) return ops_to_groups + #TODO: combine with the _get_groups_for_ops + def _get_groups_for_opsgroups(self, root_group): + """Helper function to get belonging groups for each opsgroup. + + Each pipeline has a root group. Each group has a list of operators (leaf) and groups. + This function traverse the tree and get all ancestor groups for all opsgroups. + + Returns: + A dict. Key is the opsgroup's name. Value is a list of ancestor groups including the + opsgroup itself. The list of a given opsgroup is sorted in a way that the farthest + group is the first and opsgroup itself is the last. + """ + def _get_opsgroup_groups_helper(current_groups, opsgroups_to_groups): + root_group = current_groups[-1] + for g in root_group.groups: + # Add recursive opsgroup in the ops_to_groups + # such that the i/o dependency can be propagated to the ancester opsgroups + if g.recursive_ref: + continue + opsgroups_to_groups[g.name] = [x.name for x in current_groups] + [g.name] + current_groups.append(g) + _get_opsgroup_groups_helper(current_groups, opsgroups_to_groups) + del current_groups[-1] + + opsgroups_to_groups = {} + current_groups = [root_group] + _get_opsgroup_groups_helper(current_groups, opsgroups_to_groups) + return opsgroups_to_groups + def _get_groups(self, root_group): """Helper function to get all groups (not including ops) in a pipeline.""" def _get_groups_helper(group): - groups = [group] + groups = {group.name: group} for g in group.groups: # Skip the recursive opsgroup because no templates # need to be generated for the recursive opsgroups. if not g.recursive_ref: - groups += _get_groups_helper(g) + groups.update(_get_groups_helper(g)) return groups return _get_groups_helper(root_group) - def _get_uncommon_ancestors(self, op_groups, op1, op2): + def _get_uncommon_ancestors(self, op_groups, opsgroup_groups, op1, op2): """Helper function to get unique ancestors between two ops. For example, op1's ancestor groups are [root, G1, G2, G3, op1], op2's ancestor groups are [root, G1, G4, op2], then it returns a tuple ([G2, G3, op1], [G4, op2]). """ - both_groups = [op_groups[op1.name], op_groups[op2.name]] + #TODO: extract a function for the following two code module + if op1.name in op_groups: + op1_groups = op_groups[op1.name] + elif op1.name in opsgroup_groups: + op1_groups = opsgroup_groups[op1.name] + else: + raise ValueError(op1.name + ' does not exist.') + + if op2.name in op_groups: + op2_groups = op_groups[op2.name] + elif op2.name in opsgroup_groups: + op2_groups = opsgroup_groups[op2.name] + else: + raise ValueError(op1.name + ' does not exist.') + + both_groups = [op1_groups, op2_groups] common_groups_len = sum(1 for x in zip(*both_groups) if x==(x[0],)*len(x)) - group1 = op_groups[op1.name][common_groups_len:] - group2 = op_groups[op2.name][common_groups_len:] + group1 = op1_groups[common_groups_len:] + group2 = op2_groups[common_groups_len:] return (group1, group2) - def _get_inputs_outputs(self, pipeline, root_group, op_groups): + def _get_condition_params_for_ops(self, root_group): + """Get parameters referenced in conditions of ops.""" + + conditions = defaultdict(set) + + def _get_condition_params_for_ops_helper(group, current_conditions_params): + new_current_conditions_params = current_conditions_params + if group.type == 'condition': + new_current_conditions_params = list(current_conditions_params) + if isinstance(group.condition.operand1, dsl.PipelineParam): + new_current_conditions_params.append(group.condition.operand1) + if isinstance(group.condition.operand2, dsl.PipelineParam): + new_current_conditions_params.append(group.condition.operand2) + for op in group.ops: + for param in new_current_conditions_params: + conditions[op.name].add(param) + for g in group.groups: + # If the subgroup is a recursive opsgroup, propagate the pipelineparams + # in the condition expression, similar to the ops. + if g.recursive_ref: + for param in new_current_conditions_params: + conditions[g.name].add(param) + else: + _get_condition_params_for_ops_helper(g, new_current_conditions_params) + _get_condition_params_for_ops_helper(root_group, []) + return conditions + + def _get_inputs_outputs(self, pipeline, root_group, op_groups, opsgroup_groups, condition_params): """Get inputs and outputs of each group and op. Returns: @@ -120,9 +191,9 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups): produces the param. If the param is a pipeline param (no producer op), then producing_op_name is None. """ - condition_params = self._get_condition_params_for_ops(root_group) inputs = defaultdict(set) outputs = defaultdict(set) + for op in pipeline.ops.values(): # op's inputs and all params used in conditions for that op are both considered. for param in op.inputs + list(condition_params[op.name]): @@ -134,7 +205,7 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups): if param.op_name: upstream_op = pipeline.ops[param.op_name] upstream_groups, downstream_groups = self._get_uncommon_ancestors( - op_groups, upstream_op, op) + op_groups, opsgroup_groups, upstream_op, op) for i, g in enumerate(downstream_groups): if i == 0: # If it is the first uncommon downstream group, then the input comes from @@ -161,17 +232,23 @@ def _get_inputs_outputs(self, pipeline, root_group, op_groups): def _get_inputs_outputs_recursive_opsgroup(group): #TODO: refactor the following codes with the above if group.recursive_ref: - for param in group.inputs + list(condition_params[group.name]): + params = [(param, False) for param in group.inputs] + params.extend([(param, True) for param in list(condition_params[group.name])]) + for param, is_condition_param in params: if param.value: continue full_name = self._pipelineparam_full_name(param) if param.op_name: upstream_op = pipeline.ops[param.op_name] upstream_groups, downstream_groups = self._get_uncommon_ancestors( - op_groups, upstream_op, group) + op_groups, opsgroup_groups, upstream_op, group) for i, g in enumerate(downstream_groups): if i == 0: inputs[g].add((full_name, upstream_groups[0])) + # There is no need to pass the condition param as argument to the downstream ops. + #TODO: this might also apply to ops. add a TODO here and think about it. + elif i == len(downstream_groups) - 1 and is_condition_param: + continue else: inputs[g].add((full_name, None)) for i, g in enumerate(upstream_groups): @@ -189,34 +266,7 @@ def _get_inputs_outputs_recursive_opsgroup(group): _get_inputs_outputs_recursive_opsgroup(root_group) return inputs, outputs - def _get_condition_params_for_ops(self, root_group): - """Get parameters referenced in conditions of ops.""" - - conditions = defaultdict(set) - - def _get_condition_params_for_ops_helper(group, current_conditions_params): - new_current_conditions_params = current_conditions_params - if group.type == 'condition': - new_current_conditions_params = list(current_conditions_params) - if isinstance(group.condition.operand1, dsl.PipelineParam): - new_current_conditions_params.append(group.condition.operand1) - if isinstance(group.condition.operand2, dsl.PipelineParam): - new_current_conditions_params.append(group.condition.operand2) - for op in group.ops: - for param in new_current_conditions_params: - conditions[op.name].add(param) - for g in group.groups: - # If the subgroup is a recursive opsgroup, propagate the pipelineparams - # in the condition expression, similar to the ops. - if g.recursive_ref: - for param in new_current_conditions_params: - conditions[g.name].add(param) - else: - _get_condition_params_for_ops_helper(g, new_current_conditions_params) - _get_condition_params_for_ops_helper(root_group, []) - return conditions - - def _get_dependencies(self, pipeline, root_group, op_groups): + def _get_dependencies(self, pipeline, root_group, op_groups, opsgroups_groups, opsgroups, condition_params): """Get dependent groups and ops for all ops and groups. Returns: @@ -226,39 +276,48 @@ def _get_dependencies(self, pipeline, root_group, op_groups): then G3 is dependent on G2. Basically dependency only exists in the first uncommon ancesters in their ancesters chain. Only sibling groups/ops can have dependencies. """ - #TODO: move the condition_params out because both the _get_inputs_outputs - # and _get_dependencies depend on it. - condition_params = self._get_condition_params_for_ops(root_group) dependencies = defaultdict(set) for op in pipeline.ops.values(): - unstream_op_names = set() + upstream_op_names = set() for param in op.inputs + list(condition_params[op.name]): if param.op_name: - unstream_op_names.add(param.op_name) - unstream_op_names |= set(op.dependent_op_names) + upstream_op_names.add(param.op_name) + upstream_op_names |= set(op.dependent_names) + + for op_name in upstream_op_names: + # the dependent op could be either a ContainerOp or an opsgroup + if op_name in pipeline.ops: + upstream_op = pipeline.ops[op_name] + elif op_name in opsgroups: + upstream_op = opsgroups[op_name] + else: + raise ValueError('compiler cannot find the ' + op_name) - for op_name in unstream_op_names: - upstream_op = pipeline.ops[op_name] upstream_groups, downstream_groups = self._get_uncommon_ancestors( - op_groups, upstream_op, op) + op_groups, opsgroups_groups, upstream_op, op) dependencies[downstream_groups[0]].add(upstream_groups[0]) # Generate dependencies based on the recursive opsgroups #TODO: refactor the following codes with the above def _get_dependency_opsgroup(group, dependencies): + upstream_op_names = set() if group.recursive_ref: - unstream_op_names = set() for param in group.inputs + list(condition_params[group.name]): if param.op_name: - unstream_op_names.add(param.op_name) - unstream_op_names |= set(group.dependencies) + upstream_op_names.add(param.op_name) + else: + upstream_op_names = set([dependency.name for dependency in group.dependencies]) - for op_name in unstream_op_names: + for op_name in upstream_op_names: + if op_name in pipeline.ops: upstream_op = pipeline.ops[op_name] - upstream_groups, downstream_groups = self._get_uncommon_ancestors( - op_groups, upstream_op, group) - dependencies[downstream_groups[0]].add(upstream_groups[0]) - + elif op_name in opsgroups_groups: + upstream_op = opsgroups_groups[op_name] + else: + raise ValueError('compiler cannot find the ' + op_name) + upstream_groups, downstream_groups = self._get_uncommon_ancestors( + op_groups, opsgroups_groups, upstream_op, group) + dependencies[downstream_groups[0]].add(upstream_groups[0]) for subgroup in group.groups: _get_dependency_opsgroup(subgroup, dependencies) @@ -279,7 +338,12 @@ def _resolve_value_or_reference(self, value_or_reference, potential_references): task_names = [task_name for param_name, task_name in potential_references if param_name == parameter_name] if task_names: task_name = task_names[0] - return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name) + # When the task_name is None, the parameter comes directly from ancient ancesters + # instead of parents. Thus, it is resolved as the input parameter in the current group. + if task_name is None: + return '{{inputs.parameters.%s}}' % parameter_name + else: + return '{{tasks.%s.outputs.parameters.%s}}' % (task_name, parameter_name) else: return '{{inputs.parameters.%s}}' % parameter_name else: @@ -349,16 +413,28 @@ def _group_to_template(self, group, inputs, outputs, dependencies): for param_name, dependent_name in inputs[sub_group.name]: if dependent_name: # The value comes from an upstream sibling. - arguments.append({ - 'name': param_name, - 'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name) - }) + # Special handling for recursive subgroup: argument name comes from the existing opsgroup + if is_recursive_subgroup: + for index, input in enumerate(sub_group.inputs): + if param_name == self._pipelineparam_full_name(input): + break + referenced_input = sub_group.recursive_ref.inputs[index] + full_name = self._pipelineparam_full_name(referenced_input) + arguments.append({ + 'name': full_name, + 'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name) + }) + else: + arguments.append({ + 'name': param_name, + 'value': '{{tasks.%s.outputs.parameters.%s}}' % (dependent_name, param_name) + }) else: # The value comes from its parent. # Special handling for recursive subgroup: argument name comes from the existing opsgroup if is_recursive_subgroup: for index, input in enumerate(sub_group.inputs): - if param_name == input.name: + if param_name == self._pipelineparam_full_name(input): break referenced_input = sub_group.recursive_ref.inputs[index] full_name = self._pipelineparam_full_name(referenced_input) @@ -385,21 +461,24 @@ def _create_templates(self, pipeline): # Generate core data structures to prepare for argo yaml generation # op_groups: op name -> list of ancestor groups including the current op - # inputs, outputs: group/op names -> list of tuples (param_name, producing_op_name) + # opsgroups: a dictionary of ospgroup.name -> opsgroup + # inputs, outputs: group/op names -> list of tuples (full_param_name, producing_op_name) + # condition_params: recursive_group/op names -> list of pipelineparam # dependencies: group/op name -> list of dependent groups/ops. - # groups: opsgroups # Special Handling for the recursive opsgroup # op_groups also contains the recursive opsgroups # condition_params from _get_condition_params_for_ops also contains the recursive opsgroups # groups does not include the recursive opsgroups + opsgroups = self._get_groups(new_root_group) op_groups = self._get_groups_for_ops(new_root_group) - inputs, outputs = self._get_inputs_outputs(pipeline, new_root_group, op_groups) - dependencies = self._get_dependencies(pipeline, new_root_group, op_groups) - groups = self._get_groups(new_root_group) + opsgroups_groups = self._get_groups_for_opsgroups(new_root_group) + condition_params = self._get_condition_params_for_ops(new_root_group) + inputs, outputs = self._get_inputs_outputs(pipeline, new_root_group, op_groups, opsgroups_groups, condition_params) + dependencies = self._get_dependencies(pipeline, new_root_group, op_groups, opsgroups_groups, opsgroups, condition_params) templates = [] - for g in groups: - templates.append(self._group_to_template(g, inputs, outputs, dependencies)) + for opsgroup in opsgroups.keys(): + templates.append(self._group_to_template(opsgroups[opsgroup], inputs, outputs, dependencies)) for op in pipeline.ops.values(): templates.append(_op_to_template(op)) @@ -538,8 +617,8 @@ def _compile(self, pipeline_func): if op.output is not None: op.output.name = K8sHelper.sanitize_k8s_name(op.output.name) op.output.op_name = K8sHelper.sanitize_k8s_name(op.output.op_name) - if op.dependent_op_names: - op.dependent_op_names = [K8sHelper.sanitize_k8s_name(name) for name in op.dependent_op_names] + if op.dependent_names: + op.dependent_names = [K8sHelper.sanitize_k8s_name(name) for name in op.dependent_names] if op.file_outputs is not None: sanitized_file_outputs = {} for key in op.file_outputs.keys(): diff --git a/sdk/python/kfp/dsl/_component.py b/sdk/python/kfp/dsl/_component.py index 912a81ddb7c..65b8c25ee55 100644 --- a/sdk/python/kfp/dsl/_component.py +++ b/sdk/python/kfp/dsl/_component.py @@ -149,12 +149,7 @@ def _graph_component(*args, **kargs): with graph_ops_group: # Call the function if not graph_ops_group.recursive_ref: - graph_ops_group.outputs = func(*args, **kargs) - if not isinstance(graph_ops_group.outputs, dict): - raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.') - for output in graph_ops_group.outputs: - if not (isinstance(output, str) and isinstance(graph_ops_group.outputs[output], PipelineParam)): - raise ValueError(func.__name__ + ' needs to return a dictionary of string to PipelineParam.') + func(*args, **kargs) return graph_ops_group return _graph_component \ No newline at end of file diff --git a/sdk/python/kfp/dsl/_container_op.py b/sdk/python/kfp/dsl/_container_op.py index d0dd56a1854..6634e3dd533 100644 --- a/sdk/python/kfp/dsl/_container_op.py +++ b/sdk/python/kfp/dsl/_container_op.py @@ -757,7 +757,7 @@ def _decorated(*args, **kwargs): # attributes specific to `ContainerOp` self._inputs = [] self.file_outputs = file_outputs - self.dependent_op_names = [] + self.dependent_names = [] self.is_exit_handler = is_exit_handler self._metadata = None @@ -851,7 +851,7 @@ def apply(self, mod_func): def after(self, op): """Specify explicit dependency on another op.""" - self.dependent_op_names.append(op.name) + self.dependent_names.append(op.name) return self def add_volume(self, volume): diff --git a/sdk/python/kfp/dsl/_ops_group.py b/sdk/python/kfp/dsl/_ops_group.py index 72374405623..99078916fea 100644 --- a/sdk/python/kfp/dsl/_ops_group.py +++ b/sdk/python/kfp/dsl/_ops_group.py @@ -36,6 +36,7 @@ def __init__(self, group_type: str, name: str=None): self.ops = list() self.groups = list() self.name = name + self.dependencies = [] # recursive_ref points to the opsgroups with the same name if exists. self.recursive_ref = None @@ -80,6 +81,11 @@ def __enter__(self): def __exit__(self, *args): _pipeline.Pipeline.get_default_pipeline().pop_ops_group() + def after(self, dependency): + """Specify explicit dependency on another op.""" + self.dependencies.append(dependency) + return self + class ExitHandler(OpsGroup): """Represents an exit handler that is invoked upon exiting a group of ops. @@ -101,7 +107,7 @@ def __init__(self, exit_op: _container_op.ContainerOp): ValueError is the exit_op is invalid. """ super(ExitHandler, self).__init__('exit_handler') - if exit_op.dependent_op_names: + if exit_op.dependent_names: raise ValueError('exit_op cannot depend on any other ops.') self.exit_op = exit_op @@ -137,9 +143,4 @@ def __init__(self, name): super(Graph, self).__init__(group_type='graph', name=name) self.inputs = [] self.outputs = {} - self.dependencies = [] - - def after(self, dependency): - """Specify explicit dependency on another op.""" - self.dependencies.append(dependency) - return self \ No newline at end of file + self.dependencies = [] \ No newline at end of file diff --git a/sdk/python/tests/compiler/compiler_tests.py b/sdk/python/tests/compiler/compiler_tests.py index e54dc378fa0..7f2f588f684 100644 --- a/sdk/python/tests/compiler/compiler_tests.py +++ b/sdk/python/tests/compiler/compiler_tests.py @@ -290,9 +290,13 @@ def test_py_image_pull_secret(self): """Test pipeline imagepullsecret.""" self._test_py_compile_yaml('imagepullsecret') - def test_py_recursive(self): + def test_py_recursive_do_while(self): """Test pipeline recursive.""" - self._test_py_compile_yaml('recursive') + self._test_py_compile_yaml('recursive_do_while') + + def test_py_recursive_while(self): + """Test pipeline recursive.""" + self._test_py_compile_yaml('recursive_while') def test_type_checking_with_consistent_types(self): """Test type check pipeline parameters against component metadata.""" diff --git a/sdk/python/tests/compiler/testdata/recursive.py b/sdk/python/tests/compiler/testdata/recursive_do_while.py similarity index 94% rename from sdk/python/tests/compiler/testdata/recursive.py rename to sdk/python/tests/compiler/testdata/recursive_do_while.py index c38ccfe8769..69405aa4441 100644 --- a/sdk/python/tests/compiler/testdata/recursive.py +++ b/sdk/python/tests/compiler/testdata/recursive_do_while.py @@ -44,7 +44,6 @@ def flip_component(flip_result): flipA = FlipCoinOp().after(print_flip) with dsl.Condition(flipA.output == 'heads'): flip_component(flipA.output) - return {'flip_result': flipA.output} @dsl.pipeline( name='pipeline flip coin', @@ -52,8 +51,10 @@ def flip_component(flip_result): ) def recursive(): flipA = FlipCoinOp() + flipB = FlipCoinOp() flip_loop = flip_component(flipA.output) - PrintOp('cool, it is over. %s' % flip_loop.outputs['flip_result']) + flip_loop.after(flipB) + PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop) if __name__ == '__main__': import kfp.compiler as compiler diff --git a/sdk/python/tests/compiler/testdata/recursive.yaml b/sdk/python/tests/compiler/testdata/recursive_do_while.yaml similarity index 77% rename from sdk/python/tests/compiler/testdata/recursive.yaml rename to sdk/python/tests/compiler/testdata/recursive_do_while.yaml index 89afcc2aa27..aa574e04e3c 100644 --- a/sdk/python/tests/compiler/testdata/recursive.yaml +++ b/sdk/python/tests/compiler/testdata/recursive_do_while.yaml @@ -13,12 +13,12 @@ spec: - arguments: parameters: - name: flip-output - value: '{{inputs.parameters.flip-2-output}}' + value: '{{inputs.parameters.flip-3-output}}' name: graph-flip-component-1 template: graph-flip-component-1 inputs: parameters: - - name: flip-2-output + - name: flip-3-output name: condition-2 - container: args: @@ -102,21 +102,62 @@ spec: - name: flip-2-output valueFrom: path: /tmp/output + - container: + args: + - python -c "import random; result = 'heads' if random.randint(0,1) == 0 else + 'tails'; print(result)" | tee /tmp/output + command: + - sh + - -c + image: python:alpine3.6 + name: flip-3 + outputs: + artifacts: + - name: mlpipeline-ui-metadata + path: /mlpipeline-ui-metadata.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - name: mlpipeline-metrics + path: /mlpipeline-metrics.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + parameters: + - name: flip-3-output + valueFrom: + path: /tmp/output - dag: tasks: - arguments: parameters: - - name: flip-2-output - value: '{{tasks.flip-2.outputs.parameters.flip-2-output}}' + - name: flip-3-output + value: '{{tasks.flip-3.outputs.parameters.flip-3-output}}' dependencies: - - flip-2 + - flip-3 name: condition-2 template: condition-2 - when: '{{tasks.flip-2.outputs.parameters.flip-2-output}} == heads' + when: '{{tasks.flip-3.outputs.parameters.flip-3-output}} == heads' - dependencies: - print - name: flip-2 - template: flip-2 + name: flip-3 + template: flip-3 - arguments: parameters: - name: flip-output @@ -127,28 +168,27 @@ spec: parameters: - name: flip-output name: graph-flip-component-1 - outputs: - parameters: - - name: flip-2-output - valueFrom: - parameter: '{{tasks.flip-2.outputs.parameters.flip-2-output}}' - dag: tasks: - name: flip template: flip + - name: flip-2 + template: flip-2 - arguments: parameters: - name: flip-output value: '{{tasks.flip.outputs.parameters.flip-output}}' dependencies: - flip + - flip-2 name: graph-flip-component-1 template: graph-flip-component-1 - arguments: parameters: - - name: flip-2-output - value: '{{tasks.graph-flip-component-1.outputs.parameters.flip-2-output}}' + - name: flip-output + value: '{{tasks.flip.outputs.parameters.flip-output}}' dependencies: + - flip - graph-flip-component-1 name: print-2 template: print-2 @@ -193,11 +233,11 @@ spec: - container: command: - echo - - cool, it is over. {{inputs.parameters.flip-2-output}} + - cool, it is over. {{inputs.parameters.flip-output}} image: alpine:3.6 inputs: parameters: - - name: flip-2-output + - name: flip-output name: print-2 outputs: artifacts: diff --git a/sdk/python/tests/compiler/testdata/recursive_while.py b/sdk/python/tests/compiler/testdata/recursive_while.py new file mode 100644 index 00000000000..64e5d1bb8ec --- /dev/null +++ b/sdk/python/tests/compiler/testdata/recursive_while.py @@ -0,0 +1,59 @@ +# Copyright 2019 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import kfp.dsl as dsl + +class FlipCoinOp(dsl.ContainerOp): + """Flip a coin and output heads or tails randomly.""" + + def __init__(self): + super(FlipCoinOp, self).__init__( + name='Flip', + image='python:alpine3.6', + command=['sh', '-c'], + arguments=['python -c "import random; result = \'heads\' if random.randint(0,1) == 0 ' + 'else \'tails\'; print(result)" | tee /tmp/output'], + file_outputs={'output': '/tmp/output'}) + +class PrintOp(dsl.ContainerOp): + """Print a message.""" + + def __init__(self, msg): + super(PrintOp, self).__init__( + name='Print', + image='alpine:3.6', + command=['echo', msg], + ) + +@dsl._component.graph_component +def flip_component(flip_result): + with dsl.Condition(flip_result == 'heads'): + print_flip = PrintOp(flip_result) + flipA = FlipCoinOp().after(print_flip) + flip_component(flipA.output) + +@dsl.pipeline( + name='pipeline flip coin', + description='shows how to use dsl.Condition.' +) +def flipcoin(): + flipA = FlipCoinOp() + flipB = FlipCoinOp() + flip_loop = flip_component(flipA.output) + flip_loop.after(flipB) + PrintOp('cool, it is over. %s' % flipA.output).after(flip_loop) + +if __name__ == '__main__': + import kfp.compiler as compiler + compiler.Compiler().compile(flipcoin, __file__ + '.tar.gz') diff --git a/sdk/python/tests/compiler/testdata/recursive_while.yaml b/sdk/python/tests/compiler/testdata/recursive_while.yaml new file mode 100644 index 00000000000..77b6baa6be8 --- /dev/null +++ b/sdk/python/tests/compiler/testdata/recursive_while.yaml @@ -0,0 +1,273 @@ +apiVersion: argoproj.io/v1alpha1 +kind: Workflow +metadata: + generateName: pipeline-flip-coin- +spec: + arguments: + parameters: [] + entrypoint: pipeline-flip-coin + serviceAccountName: pipeline-runner + templates: + - dag: + tasks: + - arguments: + parameters: + - name: flip-output + value: '{{inputs.parameters.flip-output}}' + dependencies: + - print + name: flip-3 + template: flip-3 + - arguments: + parameters: + - name: flip-output + value: '{{tasks.flip-3.outputs.parameters.flip-3-output}}' + dependencies: + - flip-3 + name: graph-flip-component-1 + template: graph-flip-component-1 + - arguments: + parameters: + - name: flip-output + value: '{{inputs.parameters.flip-output}}' + name: print + template: print + inputs: + parameters: + - name: flip-output + name: condition-2 + - container: + args: + - python -c "import random; result = 'heads' if random.randint(0,1) == 0 else + 'tails'; print(result)" | tee /tmp/output + command: + - sh + - -c + image: python:alpine3.6 + name: flip + outputs: + artifacts: + - name: mlpipeline-ui-metadata + path: /mlpipeline-ui-metadata.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - name: mlpipeline-metrics + path: /mlpipeline-metrics.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + parameters: + - name: flip-output + valueFrom: + path: /tmp/output + - container: + args: + - python -c "import random; result = 'heads' if random.randint(0,1) == 0 else + 'tails'; print(result)" | tee /tmp/output + command: + - sh + - -c + image: python:alpine3.6 + name: flip-2 + outputs: + artifacts: + - name: mlpipeline-ui-metadata + path: /mlpipeline-ui-metadata.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - name: mlpipeline-metrics + path: /mlpipeline-metrics.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + parameters: + - name: flip-2-output + valueFrom: + path: /tmp/output + - container: + args: + - python -c "import random; result = 'heads' if random.randint(0,1) == 0 else + 'tails'; print(result)" | tee /tmp/output + command: + - sh + - -c + image: python:alpine3.6 + name: flip-3 + outputs: + artifacts: + - name: mlpipeline-ui-metadata + path: /mlpipeline-ui-metadata.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - name: mlpipeline-metrics + path: /mlpipeline-metrics.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + parameters: + - name: flip-3-output + valueFrom: + path: /tmp/output + - dag: + tasks: + - arguments: + parameters: + - name: flip-output + value: '{{inputs.parameters.flip-output}}' + name: condition-2 + template: condition-2 + when: '{{inputs.parameters.flip-output}} == heads' + inputs: + parameters: + - name: flip-output + name: graph-flip-component-1 + - dag: + tasks: + - name: flip + template: flip + - name: flip-2 + template: flip-2 + - arguments: + parameters: + - name: flip-output + value: '{{tasks.flip.outputs.parameters.flip-output}}' + dependencies: + - flip + - flip-2 + name: graph-flip-component-1 + template: graph-flip-component-1 + - arguments: + parameters: + - name: flip-output + value: '{{tasks.flip.outputs.parameters.flip-output}}' + dependencies: + - flip + - graph-flip-component-1 + name: print-2 + template: print-2 + name: pipeline-flip-coin + - container: + command: + - echo + - '{{inputs.parameters.flip-output}}' + image: alpine:3.6 + inputs: + parameters: + - name: flip-output + name: print + outputs: + artifacts: + - name: mlpipeline-ui-metadata + path: /mlpipeline-ui-metadata.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - name: mlpipeline-metrics + path: /mlpipeline-metrics.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - container: + command: + - echo + - cool, it is over. {{inputs.parameters.flip-output}} + image: alpine:3.6 + inputs: + parameters: + - name: flip-output + name: print-2 + outputs: + artifacts: + - name: mlpipeline-ui-metadata + path: /mlpipeline-ui-metadata.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-ui-metadata.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact + - name: mlpipeline-metrics + path: /mlpipeline-metrics.json + s3: + accessKeySecret: + key: accesskey + name: mlpipeline-minio-artifact + bucket: mlpipeline + endpoint: minio-service.kubeflow:9000 + insecure: true + key: runs/{{workflow.uid}}/{{pod.name}}/mlpipeline-metrics.tgz + secretKeySecret: + key: secretkey + name: mlpipeline-minio-artifact diff --git a/sdk/python/tests/dsl/component_tests.py b/sdk/python/tests/dsl/component_tests.py index 6a10f03d08d..89974cb6c0d 100644 --- a/sdk/python/tests/dsl/component_tests.py +++ b/sdk/python/tests/dsl/component_tests.py @@ -433,7 +433,6 @@ def test_graphcomponent_basic(self): def flip_component(flip_result): with dsl.Condition(flip_result == 'heads'): flip_component(flip_result) - return {'flip_result': flip_result} with Pipeline('pipeline') as p: param = PipelineParam(name='param') @@ -447,6 +446,3 @@ def flip_component(flip_result): self.assertTrue(recursive_group.recursive_ref is not None) self.assertEqual(1, len(recursive_group.inputs)) self.assertEqual('param', recursive_group.inputs[0].name) - original_group = p.groups[0].groups[0] - self.assertTrue('flip_result' in original_group.outputs) - self.assertEqual('param', original_group.outputs['flip_result']) diff --git a/sdk/python/tests/dsl/container_op_tests.py b/sdk/python/tests/dsl/container_op_tests.py index 840b54c2709..25fea3984c4 100644 --- a/sdk/python/tests/dsl/container_op_tests.py +++ b/sdk/python/tests/dsl/container_op_tests.py @@ -49,7 +49,7 @@ def test_after_op(self): op1 = ContainerOp(name='op1', image='image') op2 = ContainerOp(name='op2', image='image') op2.after(op1) - self.assertCountEqual(op2.dependent_op_names, [op1.name]) + self.assertCountEqual(op2.dependent_names, [op1.name]) def test_deprecation_warnings(self):