Skip to content

Commit

Permalink
fix(backend): fix simple loop bug (kubeflow#7578)
Browse files Browse the repository at this point in the history
* support IR YAML format in API

* Check the error message and return false if it is not nil

* update error message

* fixed simple loop but need cleaning up

* Deleted debug logs

* remove logs and fix some format

* fix static_loop_arguments

* change the driver image 

change the driver image back to the kfp container registry.

* change variable declaration

* remove logs

* remove log

* move `ok` definition

* change test file for debug purpose

* change test for debug purpose

* update sample test for static loop

* update test file, remove code for debug
  • Loading branch information
Linchin authored and abaland committed May 29, 2022
1 parent 757fe88 commit 206515a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 13 deletions.
41 changes: 37 additions & 4 deletions backend/src/v2/driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
executorInput := &pipelinespec.ExecutorInput{
Inputs: inputs,
}
glog.Infof("executorInput value: %+v", executorInput)
execution = &Execution{ExecutorInput: executorInput}
condition := opts.Task.GetTriggerPolicy().GetCondition()
if condition != "" {
Expand All @@ -436,14 +437,37 @@ func DAG(ctx context.Context, opts Options, mlmd *metadata.Client) (execution *E
return execution, fmt.Errorf("ArtifactIterator is not implemented")
}
isIterator := opts.Task.GetParameterIterator() != nil && opts.IterationIndex < 0
// Fan out iterations
if execution.WillTrigger() && isIterator {
iterator := opts.Task.GetParameterIterator()
value, ok := executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()]
report := func(err error) error {
return fmt.Errorf("iterating on item input %q failed: %w", iterator.GetItemInput(), err)
}
if !ok {
return execution, report(fmt.Errorf("cannot find input parameter"))
// Check the items type of parameterIterator:
// It can be "inputParameter" or "Raw"
var value *structpb.Value
switch iterator.GetItems().GetKind().(type) {
case *pipelinespec.ParameterIteratorSpec_ItemsSpec_InputParameter:
var ok bool
value, ok = executorInput.GetInputs().GetParameterValues()[iterator.GetItems().GetInputParameter()]
if !ok {
return execution, report(fmt.Errorf("cannot find input parameter"))
}
case *pipelinespec.ParameterIteratorSpec_ItemsSpec_Raw:
value_raw := iterator.GetItems().GetRaw()
var unmarshalled_raw interface{}
err = json.Unmarshal([]byte(value_raw), &unmarshalled_raw)
if err != nil {
return execution, fmt.Errorf("error unmarshall raw string: %q", err)
}
value, err = structpb.NewValue(unmarshalled_raw)
if err != nil {
return execution, fmt.Errorf("error converting unmarshalled raw string into protobuf Value type: %q", err)
}
// Add the raw input to the executor input
execution.ExecutorInput.Inputs.ParameterValues[iterator.GetItemInput()] = value
default:
return execution, fmt.Errorf("cannot find parameter iterator")
}
items, err := getItems(value)
if err != nil {
Expand Down Expand Up @@ -724,7 +748,16 @@ func resolveInputs(ctx context.Context, dag *metadata.DAG, iterationIndex *int,
case task.GetArtifactIterator() != nil:
return nil, fmt.Errorf("artifact iterator not implemented yet")
case task.GetParameterIterator() != nil:
itemsInput := task.GetParameterIterator().GetItems().GetInputParameter()
var itemsInput string
if task.GetParameterIterator().GetItems().GetInputParameter() != "" {
// input comes from outside the component
itemsInput = task.GetParameterIterator().GetItems().GetInputParameter()
} else if task.GetParameterIterator().GetItemInput() != "" {
// input comes from static input
itemsInput = task.GetParameterIterator().GetItemInput()
} else {
return nil, fmt.Errorf("cannot retrieve parameter iterator.")
}
items, err := getItems(inputs.ParameterValues[itemsInput])
if err != nil {
return nil, err
Expand Down
10 changes: 5 additions & 5 deletions samples/core/loop_static/loop_static_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,14 @@
import kfp_server_api
from .loop_static import my_pipeline
from .loop_static_v2 import my_pipeline as my_pipeline_v2
from kfp.samples.test.utils import KfpTask, debug_verify, run_pipeline_func, TestCase
from kfp.samples.test.utils import KfpTask, run_pipeline_func, TestCase


def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
tasks: dict[str, KfpTask], **kwargs):
t.assertEqual(run.status, 'Succeeded')
# assert DAG structure
t.assertCountEqual(['print-op', 'for-loop-1'], tasks.keys())
t.assertCountEqual(['print-op', 'for-loop-2'], tasks.keys())
# assert all iteration parameters
t.assertCountEqual(
[{
Expand All @@ -37,14 +37,14 @@ def verify(t: unittest.TestCase, run: kfp_server_api.ApiRun,
}],
[
x.inputs
.parameters['pipelinechannel--static_loop_arguments-loop-item']
for x in tasks['for-loop-1'].children.values()
.parameters['pipelinechannel--loop-item-param-1']
for x in tasks['for-loop-2'].children.values()
],
)
# assert all iteration outputs
t.assertCountEqual(['12', '1020'], [
x.children['print-op-2'].outputs.parameters['Output']
for x in tasks['for-loop-1'].children.values()
for x in tasks['for-loop-2'].children.values()
])


Expand Down
5 changes: 1 addition & 4 deletions samples/core/loop_static/loop_static_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ def concat_op(a: str, b: str) -> str:
return a + b


_DEFAULT_LOOP_ARGUMENTS = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]


@dsl.pipeline(name='pipeline-with-loop-static')
def my_pipeline(
static_loop_arguments: List[dict] = _DEFAULT_LOOP_ARGUMENTS,
greeting: str = 'this is a test for looping through parameters',
):
print_task = print_op(text=greeting)
static_loop_arguments = [{'a': '1', 'b': '2'}, {'a': '10', 'b': '20'}]

with dsl.ParallelFor(static_loop_arguments) as item:
concat_task = concat_op(a=item.a, b=item.b)
Expand Down

0 comments on commit 206515a

Please sign in to comment.