Skip to content

Commit

Permalink
fix(sdk): Prevents dsl.ParallelFor over single parameter from compili…
Browse files Browse the repository at this point in the history
…ng. (kubeflow#10494)

* fix(sdk): Prevents dsl.ParallelFor over single paramter from compiling.

* fix(sdk): Prevents dsl.ParallelFor over single paramter from compiling.

* update PR number in release notes

* formatting

* Add compiler_test.py test for single param compile failure

* Update some docstrings and add todo

* formatting

* Update sdk/python/kfp/compiler/compiler_test.py

Co-authored-by: Connor McCarthy <[email protected]>

* Update sdk/python/kfp/compiler/compiler_test.py

Co-authored-by: Connor McCarthy <[email protected]>

* Update sdk/python/kfp/dsl/for_loop.py

Co-authored-by: Connor McCarthy <[email protected]>

* Use print_and_return and other small changes

* typo

* typo

---------

Co-authored-by: Connor McCarthy <[email protected]>
  • Loading branch information
2 people authored and petethegreat committed Mar 27, 2024
1 parent 3eb3bf6 commit 3899ebb
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 6 deletions.
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
## Deprecations

## Bug fixes and other changes
* Throw compilation error when trying to iterate over a single parameter with ParallelFor [\#10494](https://github.com/kubeflow/pipelines/pull/10494)

## Documentation updates

Expand Down
13 changes: 13 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,19 @@ def my_pipeline(text: bool):
with self.assertRaises(KeyError):
for_loop_4['iteratorPolicy']

def test_cannot_compile_parallel_for_with_single_param(self):

with self.assertRaisesRegex(
ValueError,
r'Cannot iterate over a single parameter using `dsl\.ParallelFor`\. Expected a list of parameters as argument to `items`\.'
):

@dsl.pipeline
def my_pipeline():
single_param_task = print_and_return(text='string')
with dsl.ParallelFor(items=single_param_task.output) as item:
print_and_return(text=item)

def test_pipeline_in_pipeline(self):

@dsl.pipeline(name='graph-component')
Expand Down
21 changes: 17 additions & 4 deletions sdk/python/kfp/dsl/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from typing import Any, Dict, List, Optional, Union

from kfp.dsl import pipeline_channel
from kfp.dsl.types import type_annotations
from kfp.dsl.types import type_utils

ItemList = List[Union[int, float, str, Dict[str, Any]]]

Expand Down Expand Up @@ -124,7 +126,7 @@ def __init__(
Python variable name.
name_code: A unique code used to identify these loop arguments.
Should match the code for the ParallelFor ops_group which created
these LoopArguments. This prevents parameter name collisions.
these LoopParameterArguments. This prevents parameter name collisions.
name_override: The override name for PipelineParameterChannel.
**kwargs: Any other keyword arguments passed down to PipelineParameterChannel.
"""
Expand Down Expand Up @@ -166,7 +168,7 @@ def __init__(

def __getattr__(self, name: str):
# this is being overridden so that we can access subvariables of the
# LoopArgument (i.e.: item.a) without knowing the subvariable names ahead
# LoopParameterArgument (i.e.: item.a) without knowing the subvariable names ahead
# of time.

return self._referenced_subvars.setdefault(
Expand All @@ -188,6 +190,17 @@ def from_pipeline_channel(
compilation progress in cases of unknown or missing type
information.
"""
# if channel is a LoopArgumentVariable, current system cannot check if
# nested items are lists.
if not isinstance(channel, LoopArgumentVariable):
type_name = type_annotations.get_short_type_name(
channel.channel_type)
parameter_type = type_utils.PARAMETER_TYPES_MAPPING[
type_name.lower()]
if parameter_type != type_utils.LIST:
raise ValueError(
'Cannot iterate over a single parameter using `dsl.ParallelFor`. Expected a list of parameters as argument to `items`.'
)
return LoopParameterArgument(
items=channel,
name_override=channel.name + '-' + LOOP_ITEM_NAME_BASE,
Expand Down Expand Up @@ -297,7 +310,7 @@ class LoopArgumentVariable(pipeline_channel.PipelineParameterChannel):
Then there's one LoopArgumentVariable for 'a' and another for 'b'.
Attributes:
loop_argument: The original LoopArgument object this subvariable is
loop_argument: The original LoopParameterArgument object this subvariable is
attached to.
subvar_name: The subvariable name.
"""
Expand Down Expand Up @@ -327,7 +340,7 @@ def __init__(

self.subvar_name = subvar_name
self.loop_argument = loop_argument
# Handle potential channel_type extraction errors from LoopArgument by defaulting to 'String'. This maintains compilation progress.
# Handle potential channel_type extraction errors from LoopParameterArgument by defaulting to 'String'. This maintains compilation progress.
super().__init__(
name=self._get_name_override(
loop_arg_name=loop_argument.name,
Expand Down
18 changes: 18 additions & 0 deletions sdk/python/kfp/dsl/for_loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ def test_loop_parameter_argument_from_pipeline_channel(
self.assertEqual(loop_argument.items_or_pipeline_channel, channel)
self.assertEqual(str(loop_argument), expected_serialization_value)

@parameterized.parameters(
{
'channel':
pipeline_channel.PipelineParameterChannel(
name='param1',
channel_type='String',
task_name='task1',
),
},)
def test_loop_parameter_argument_from_single_pipeline_channel_raises_error(
self, channel):
with self.assertRaisesRegex(
ValueError,
r'Cannot iterate over a single parameter using `dsl\.ParallelFor`\. Expected a list of parameters as argument to `items`\.'
):
loop_argument = for_loop.LoopParameterArgument.from_pipeline_channel(
channel)

@parameterized.parameters(
{
'channel':
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kfp/dsl/types/type_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ class TestTypeChecking(parameterized.TestCase):
loop_argument=for_loop.LoopParameterArgument
.from_pipeline_channel(
pipeline_channel.create_pipeline_channel(
'Output-loop-item', 'String',
'Output-loop-item', 'List[str]',
'list-dict-without-type-maker-5')),
subvar_name='a'),
'parameter_input_spec':
Expand All @@ -732,7 +732,7 @@ class TestTypeChecking(parameterized.TestCase):
'argument_value':
for_loop.LoopParameterArgument.from_pipeline_channel(
pipeline_channel.create_pipeline_channel(
'Output-loop-item', 'String',
'Output-loop-item', 'List[int]',
'list-dict-without-type-maker-5')),
'parameter_input_spec':
structures.InputSpec('Integer'),
Expand Down

0 comments on commit 3899ebb

Please sign in to comment.