diff --git a/backend/src/v2/driver/driver.go b/backend/src/v2/driver/driver.go index 87eb211892f..5916ebcd7a6 100644 --- a/backend/src/v2/driver/driver.go +++ b/backend/src/v2/driver/driver.go @@ -414,6 +414,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E executorInput := &pipelinespec.ExecutorInput{ Inputs: inputs, } + glog.Infof("executorInput value: %+v", executorInput) execution = &Execution{ExecutorInput: executorInput} condition := opts.Task.GetTriggerPolicy().GetCondition() if condition != "" { @@ -436,14 +437,37 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E return execution, fmt.Errorf("ArtifactIterator is not implemented") } isIterator := opts.Task.GetParameterIterator() != nil && opts.IterationIndex < 0 + // Fan out iterations if execution.WillTrigger() && isIterator { iterator := opts.Task.GetParameterIterator() - value, ok := executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()] report := func(err error) error { return fmt.Errorf("iterating on item input %q failed: %w", iterator.GetItemInput(), err) } - if !ok { - return execution, report(fmt.Errorf("cannot find input parameter")) + // Check the items type of parameterIterator: + // It can be "inputParameter" or "Raw" + var value *structpb.Value + switch iterator.GetItems().GetKind().(type) { + case *pipelinespec.ParameterIteratorSpec_ItemsSpec_InputParameter: + var ok bool + value, ok = executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()] + if !ok { + return execution, report(fmt.Errorf("cannot find input parameter")) + } + case *pipelinespec.ParameterIteratorSpec_ItemsSpec_Raw: + value_raw := iterator.GetItems().GetRaw() + var unmarshalled_raw interface{} + err = json.Unmarshal([]byte(value_raw), &unmarshalled_raw) + if err != nil { + return execution, fmt.Errorf("error unmarshall raw string: %q", err) + } + value, err = structpb.NewValue(unmarshalled_raw) + if err != nil { + return execution, fmt.Errorf("error converting unmarshalled raw string into protobuf Value type: %q", err) + } + // Add the raw input to the executor input + execution.ExecutorInput.Inputs.ParameterValues[iterator.GetItemInput()] = value + default: + return execution, fmt.Errorf("cannot find parameter iterator") } items, err := getItems(value) if err != nil { @@ -724,7 +748,16 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int, case task.GetArtifactIterator() != nil: return nil, fmt.Errorf("artifact iterator not implemented yet") case task.GetParameterIterator() != nil: - itemsInput := task.GetParameterIterator().GetItems().GetInputParameter() + var itemsInput string + if task.GetParameterIterator().GetItems().GetInputParameter() != "" { + // input comes from outside the component + itemsInput = task.GetParameterIterator().GetItems().GetInputParameter() + } else if task.GetParameterIterator().GetItemInput() != "" { + // input comes from static input + itemsInput = task.GetParameterIterator().GetItemInput() + } else { + return nil, fmt.Errorf("cannot retrieve parameter iterator.") + } items, err := getItems(inputs.ParameterValues[itemsInput]) if err != nil { return nil, err diff --git a/samples/core/loop_static/loop_static_test.py b/samples/core/loop_static/loop_static_test.py index c11ad9312e2..909d6261f9d 100644 --- a/samples/core/loop_static/loop_static_test.py +++ b/samples/core/loop_static/loop_static_test.py @@ -18,14 +18,14 @@ import kfp_server_api from .loop_static import my_pipeline from .loop_static_v2 import my_pipeline as my_pipeline_v2 -from kfp.samples.test.utils import KfpTask, debug_verify, run_pipeline_func, TestCase +from kfp.samples.test.utils import KfpTask, run_pipeline_func, TestCase def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, tasks: dict[str, KfpTask], **kwargs): t.assertEqual(run.status, 'Succeeded') # assert DAG structure - t.assertCountEqual(['print-op', 'for-loop-1'], tasks.keys()) + t.assertCountEqual(['print-op', 'for-loop-2'], tasks.keys()) # assert all iteration parameters t.assertCountEqual( [{ @@ -37,14 +37,14 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun, }], [ x.inputs - .parameters['pipelinechannel--static_loop_arguments-loop-item'] - for x in tasks['for-loop-1'].children.values() + .parameters['pipelinechannel--loop-item-param-1'] + for x in tasks['for-loop-2'].children.values() ], ) # assert all iteration outputs t.assertCountEqual(['12', '1020'], [ x.children['print-op-2'].outputs.parameters['Output'] - for x in tasks['for-loop-1'].children.values() + for x in tasks['for-loop-2'].children.values() ]) diff --git a/samples/core/loop_static/loop_static_v2.py b/samples/core/loop_static/loop_static_v2.py index a5c1eafe9c0..62f92b6e814 100644 --- a/samples/core/loop_static/loop_static_v2.py +++ b/samples/core/loop_static/loop_static_v2.py @@ -19,15 +19,12 @@ def concat_op(a: str, b: str) -> str: return a + b -_DEFAULT_LOOP_ARGUMENTS = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}] - - @dsl.pipeline(name='pipeline-with-loop-static') def my_pipeline( - static_loop_arguments: List[dict] = _DEFAULT_LOOP_ARGUMENTS, greeting: str = 'this is a test for looping through parameters', ): print_task = print_op(text=greeting) + static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}] with dsl.ParallelFor(static_loop_arguments) as item: concat_task = concat_op(a=item.a, b=item.b)