Skip to content

Commit

Permalink
SDK - Tests - Improved tests for serializing lists containing objects (
Browse files Browse the repository at this point in the history
…#3326)

Added test_fail_on_handling_list_arguments_containing_python_objects
Added test_handling_list_arguments_containing_serializable_python_objects
Moved test_handling_list_arguments_containing_pipelineparam to component_bridge_tests
  • Loading branch information
Ark-kun authored Mar 24, 2020
1 parent 9873d24 commit 7ee500f
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 17 deletions.
64 changes: 48 additions & 16 deletions sdk/python/tests/components/test_python_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from pathlib import Path
from typing import Callable, NamedTuple, Sequence

import kfp
import kfp.components as comp
from kfp.components import InputPath, InputTextFile, InputBinaryFile, OutputPath, OutputTextFile, OutputBinaryFile
from kfp.components.structures import InputSpec, OutputSpec
Expand Down Expand Up @@ -513,22 +512,57 @@ def assert_values_are_same(
])


def test_handling_list_arguments_containing_pipelineparam(self):
'''Checks that lists containing PipelineParam can be properly serialized'''
def consume_list(list_param: list) -> int:
def test_fail_on_handling_list_arguments_containing_python_objects(self):
'''Checks that lists containing python objects not having .to_struct() raise error during serialization.'''

class MyClass:
pass

import kfp
task_factory = comp.func_to_container_op(consume_list)
task = task_factory([1, 2, 3, kfp.dsl.PipelineParam("aaa"), 4, 5, 6])
resolved_cmd = _resolve_command_line_and_paths(
task.component_ref.spec,
task.arguments,
)
full_command_line = resolved_cmd.command + resolved_cmd.args
for arg in full_command_line:
self.assertNotIn('PipelineParam', arg)
def consume_list(
list_param: list,
) -> int:
return 1

def consume_dict(
dict_param: dict,
) -> int:
return 1

list_op = comp.create_component_from_func(consume_list)
dict_op = comp.create_component_from_func(consume_dict)

with self.assertRaises(Exception):
list_op([1, MyClass(), 3])

with self.assertRaises(Exception):
dict_op({'k1': MyClass()})

def test_handling_list_arguments_containing_serializable_python_objects(self):
'''Checks that lists containing python objects with .to_struct() can be properly serialized.'''

class MyClass:
def to_struct(self):
return {'foo': [7, 42]}

def assert_values_are_correct(
list_param: list,
dict_param: dict,
) -> int:
import unittest
unittest.TestCase().assertEqual(list_param, [1, {'foo': [7, 42]}, 3])
unittest.TestCase().assertEqual(dict_param, {'k1': {'foo': [7, 42]}})
return 1

task_factory = comp.create_component_from_func(assert_values_are_correct)

self.helper_test_component_using_local_call(
task_factory,
arguments=dict(
list_param=[1, MyClass(), 3],
dict_param={'k1': MyClass()},
),
expected_output_values={'Output': '1'},
)

def test_handling_base64_pickle_arguments(self):
def assert_values_are_same(
Expand Down Expand Up @@ -831,8 +865,6 @@ def test_packages_to_install_feature(self):


def test_end_to_end_python_component_pipeline(self):
import kfp.components as comp

#Defining the Python function
def add(a: float, b: float) -> float:
'''Returns sum of two arguments'''
Expand Down
15 changes: 14 additions & 1 deletion sdk/python/tests/dsl/component_bridge_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import unittest
import kfp
from pathlib import Path
from kfp.components import load_component_from_text
from kfp.components import load_component_from_text, create_component_from_func
from kfp.dsl.types import InconsistentTypeException


Expand Down Expand Up @@ -181,3 +181,16 @@ def calc_pipeline(
#Compiling the pipleine:
pipeline_filename = str(Path(temp_dir_name).joinpath(calc_pipeline.__name__ + '.pipeline.tar.gz'))
kfp.compiler.Compiler().compile(calc_pipeline, pipeline_filename)

def test_handling_list_arguments_containing_pipelineparam(self):
'''Checks that lists containing PipelineParam can be properly serialized'''
def consume_list(list_param: list) -> int:
pass

import kfp
task_factory = create_component_from_func(consume_list)
task = task_factory([1, 2, 3, kfp.dsl.PipelineParam('aaa'), 4, 5, 6])

full_command_line = task.command + task.arguments
for arg in full_command_line:
self.assertNotIn('PipelineParam', arg)

0 comments on commit 7ee500f

Please sign in to comment.