Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(backend): fix simple loop bug #7578

Merged
merged 22 commits into from
May 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 37 additions & 4 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
executorInput := &pipelinespec.ExecutorInput{
Inputs: inputs,
}
glog.Infof("executorInput value: %+v", executorInput)
execution = &Execution{ExecutorInput: executorInput}
condition := opts.Task.GetTriggerPolicy().GetCondition()
if condition != "" {
Expand All @@ -436,14 +437,37 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
return execution, fmt.Errorf("ArtifactIterator is not implemented")
}
isIterator := opts.Task.GetParameterIterator() != nil && opts.IterationIndex < 0
// Fan out iterations
if execution.WillTrigger() && isIterator {
iterator := opts.Task.GetParameterIterator()
value, ok := executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()]
report := func(err error) error {
return fmt.Errorf("iterating on item input %q failed: %w", iterator.GetItemInput(), err)
}
if !ok {
return execution, report(fmt.Errorf("cannot find input parameter"))
// Check the items type of parameterIterator:
// It can be "inputParameter" or "Raw"
var value *structpb.Value
switch iterator.GetItems().GetKind().(type) {
case *pipelinespec.ParameterIteratorSpec_ItemsSpec_InputParameter:
var ok bool
value, ok = executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()]
if !ok {
return execution, report(fmt.Errorf("cannot find input parameter"))
}
case *pipelinespec.ParameterIteratorSpec_ItemsSpec_Raw:
value_raw := iterator.GetItems().GetRaw()
var unmarshalled_raw interface{}
err = json.Unmarshal([]byte(value_raw), &unmarshalled_raw)
if err != nil {
return execution, fmt.Errorf("error unmarshall raw string: %q", err)
}
value, err = structpb.NewValue(unmarshalled_raw)
if err != nil {
return execution, fmt.Errorf("error converting unmarshalled raw string into protobuf Value type: %q", err)
}
// Add the raw input to the executor input
execution.ExecutorInput.Inputs.ParameterValues[iterator.GetItemInput()] = value
default:
return execution, fmt.Errorf("cannot find parameter iterator")
}
items, err := getItems(value)
if err != nil {
Expand Down Expand Up @@ -724,7 +748,16 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
case task.GetArtifactIterator() != nil:
return nil, fmt.Errorf("artifact iterator not implemented yet")
case task.GetParameterIterator() != nil:
itemsInput := task.GetParameterIterator().GetItems().GetInputParameter()
var itemsInput string
if task.GetParameterIterator().GetItems().GetInputParameter() != "" {
// input comes from outside the component
itemsInput = task.GetParameterIterator().GetItems().GetInputParameter()
} else if task.GetParameterIterator().GetItemInput() != "" {
// input comes from static input
itemsInput = task.GetParameterIterator().GetItemInput()
} else {
return nil, fmt.Errorf("cannot retrieve parameter iterator.")
}
items, err := getItems(inputs.ParameterValues[itemsInput])
if err != nil {
return nil, err
Expand Down
10 changes: 5 additions & 5 deletions samples/core/loop_static/loop_static_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import kfp_server_api
from .loop_static import my_pipeline
from .loop_static_v2 import my_pipeline as my_pipeline_v2
from kfp.samples.test.utils import KfpTask, debug_verify, run_pipeline_func, TestCase
from kfp.samples.test.utils import KfpTask, run_pipeline_func, TestCase


def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
tasks: dict[str, KfpTask], **kwargs):
t.assertEqual(run.status, 'Succeeded')
# assert DAG structure
t.assertCountEqual(['print-op', 'for-loop-1'], tasks.keys())
t.assertCountEqual(['print-op', 'for-loop-2'], tasks.keys())
# assert all iteration parameters
t.assertCountEqual(
[{
Expand All @@ -37,14 +37,14 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
}],
[
x.inputs
.parameters['pipelinechannel--static_loop_arguments-loop-item']
for x in tasks['for-loop-1'].children.values()
.parameters['pipelinechannel--loop-item-param-1']
for x in tasks['for-loop-2'].children.values()
],
)
# assert all iteration outputs
t.assertCountEqual(['12', '1020'], [
x.children['print-op-2'].outputs.parameters['Output']
for x in tasks['for-loop-1'].children.values()
for x in tasks['for-loop-2'].children.values()
])


Expand Down
5 changes: 1 addition & 4 deletions samples/core/loop_static/loop_static_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ def concat_op(a: str, b: str) -> str:
return a + b


_DEFAULT_LOOP_ARGUMENTS = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]


@dsl.pipeline(name='pipeline-with-loop-static')
def my_pipeline(
static_loop_arguments: List[dict] = _DEFAULT_LOOP_ARGUMENTS,
greeting: str = 'this is a test for looping through parameters',
):
print_task = print_op(text=greeting)
static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]

with dsl.ParallelFor(static_loop_arguments) as item:
concat_task = concat_op(a=item.a, b=item.b)
Expand Down