Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sdk): Prevents dsl.ParallelFor over single parameter from compiling. #10494

Merged
merged 15 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
## Deprecations

## Bug fixes and other changes
* Throw compilation error when trying to iterate over a single parameter with ParallelFor [\#10494](https://github.com/kubeflow/pipelines/pull/10494)

## Documentation updates

Expand Down
13 changes: 13 additions & 0 deletions sdk/python/kfp/compiler/compiler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,6 +825,19 @@ def my_pipeline(text: bool):
with self.assertRaises(KeyError):
for_loop_4['iteratorPolicy']

def test_cannot_compile_parallel_for_with_single_param(self):
    """Checks that passing a task output to `dsl.ParallelFor` is rejected.

    The pipeline below hands `print_and_return(...).output` directly to
    `items`; defining it is expected to raise a `ValueError` whose message
    matches the pattern asserted here.
    """
    expected_message = r'Cannot iterate over a single parameter using `dsl\.ParallelFor`\. Expected a list of parameters as argument to `items`\.'

    with self.assertRaisesRegex(ValueError, expected_message):

        @dsl.pipeline
        def my_pipeline():
            producer_task = print_and_return(text='string')
            # The task output is used as the loop's `items` argument.
            with dsl.ParallelFor(items=producer_task.output) as item:
                print_and_return(text=item)

def test_pipeline_in_pipeline(self):

@dsl.pipeline(name='graph-component')
Expand Down
21 changes: 17 additions & 4 deletions sdk/python/kfp/dsl/for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
from typing import Any, Dict, List, Optional, Union

from kfp.dsl import pipeline_channel
from kfp.dsl.types import type_annotations
from kfp.dsl.types import type_utils

ItemList = List[Union[int, float, str, Dict[str, Any]]]

Expand Down Expand Up @@ -124,7 +126,7 @@ def __init__(
Python variable name.
name_code: A unique code used to identify these loop arguments.
Should match the code for the ParallelFor ops_group which created
these LoopArguments. This prevents parameter name collisions.
these LoopParameterArguments. This prevents parameter name collisions.
name_override: The override name for PipelineParameterChannel.
**kwargs: Any other keyword arguments passed down to PipelineParameterChannel.
"""
Expand Down Expand Up @@ -166,7 +168,7 @@ def __init__(

def __getattr__(self, name: str):
# this is being overridden so that we can access subvariables of the
# LoopArgument (i.e.: item.a) without knowing the subvariable names ahead
# LoopParameterArgument (i.e.: item.a) without knowing the subvariable names ahead
# of time.

return self._referenced_subvars.setdefault(
Expand All @@ -188,6 +190,17 @@ def from_pipeline_channel(
compilation progress in cases of unknown or missing type
information.
"""
# If channel is a LoopArgumentVariable, we currently cannot check whether
# its nested items are lists, so the list-type check below is skipped.
if not isinstance(channel, LoopArgumentVariable):
type_name = type_annotations.get_short_type_name(
channel.channel_type)
parameter_type = type_utils.PARAMETER_TYPES_MAPPING[
type_name.lower()]
if parameter_type != type_utils.LIST:
raise ValueError(
'Cannot iterate over a single parameter using `dsl.ParallelFor`. Expected a list of parameters as argument to `items`.'
)
return LoopParameterArgument(
items=channel,
name_override=channel.name + '-' + LOOP_ITEM_NAME_BASE,
Expand Down Expand Up @@ -297,7 +310,7 @@ class LoopArgumentVariable(pipeline_channel.PipelineParameterChannel):
Then there's one LoopArgumentVariable for 'a' and another for 'b'.

Attributes:
loop_argument: The original LoopArgument object this subvariable is
loop_argument: The original LoopParameterArgument object this subvariable is
attached to.
subvar_name: The subvariable name.
"""
Expand Down Expand Up @@ -327,7 +340,7 @@ def __init__(

self.subvar_name = subvar_name
self.loop_argument = loop_argument
# Handle potential channel_type extraction errors from LoopArgument by defaulting to 'String'. This maintains compilation progress.
# Handle potential channel_type extraction errors from LoopParameterArgument by defaulting to 'String'. This maintains compilation progress.
super().__init__(
name=self._get_name_override(
loop_arg_name=loop_argument.name,
Expand Down
18 changes: 18 additions & 0 deletions sdk/python/kfp/dsl/for_loop_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,24 @@ def test_loop_parameter_argument_from_pipeline_channel(
self.assertEqual(loop_argument.items_or_pipeline_channel, channel)
self.assertEqual(str(loop_argument), expected_serialization_value)

@parameterized.parameters(
    {
        'channel':
            pipeline_channel.PipelineParameterChannel(
                name='param1',
                channel_type='String',
                task_name='task1',
            ),
    },)
def test_loop_parameter_argument_from_single_pipeline_channel_raises_error(
        self, channel):
    """Checks that a 'String'-typed channel cannot be used as loop items.

    Args:
        channel: A `PipelineParameterChannel` whose `channel_type` is
            'String', supplied by the parameterized decorator.

    `LoopParameterArgument.from_pipeline_channel` is expected to raise a
    `ValueError` with the message matched below.
    """
    with self.assertRaisesRegex(
            ValueError,
            r'Cannot iterate over a single parameter using `dsl\.ParallelFor`\. Expected a list of parameters as argument to `items`\.'
    ):
        # The return value is intentionally not bound: the call must raise,
        # so any assignment target would never be set or read.
        for_loop.LoopParameterArgument.from_pipeline_channel(channel)

KevinGrantLee marked this conversation as resolved.
Show resolved Hide resolved
@parameterized.parameters(
{
'channel':
Expand Down
4 changes: 2 additions & 2 deletions sdk/python/kfp/dsl/types/type_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ class TestTypeChecking(parameterized.TestCase):
loop_argument=for_loop.LoopParameterArgument
.from_pipeline_channel(
pipeline_channel.create_pipeline_channel(
'Output-loop-item', 'String',
'Output-loop-item', 'List[str]',
'list-dict-without-type-maker-5')),
subvar_name='a'),
'parameter_input_spec':
Expand All @@ -732,7 +732,7 @@ class TestTypeChecking(parameterized.TestCase):
'argument_value':
for_loop.LoopParameterArgument.from_pipeline_channel(
pipeline_channel.create_pipeline_channel(
'Output-loop-item', 'String',
'Output-loop-item', 'List[int]',
'list-dict-without-type-maker-5')),
'parameter_input_spec':
structures.InputSpec('Integer'),
Expand Down