SDK - Lightweight - Added support for file inputs #2207

Merged
sdk/python/kfp/components/_python_op.py (68 changes: 59 additions & 9 deletions)
@@ -17,6 +17,9 @@
    'func_to_component_text',
    'get_default_base_image',
    'set_default_base_image',
    'InputPath',
    'InputTextFile',
    'InputBinaryFile',
]

from ._yaml_utils import dump_yaml
@@ -31,16 +34,34 @@

T = TypeVar('T')

#OutputFile[GcsPath[Gzipped[Text]]]

# InputPath(list) or InputPath('JsonObject')

class InputFile(Generic[T], str):
    pass
class InputPath:
    '''When creating a component from a function, InputPath should be used as a function parameter annotation to tell the system to pass the *data file path* to the function instead of passing the actual data.'''
    def __init__(self, type=None):
        self.type = type


class InputTextFile:
    '''When creating a component from a function, InputTextFile should be used as a function parameter annotation to tell the system to pass the *text data stream* object (`io.TextIOWrapper`) to the function instead of passing the actual data.'''
    def __init__(self, type=None):
        self.type = type


class InputBinaryFile:
    '''When creating a component from a function, InputBinaryFile should be used as a function parameter annotation to tell the system to pass the *binary data stream* object (`io.BytesIO`) to the function instead of passing the actual data.'''
    def __init__(self, type=None):
        self.type = type


#OutputFile[GcsPath[Gzipped[Text]]]


class OutputFile(Generic[T], str):
    pass


#TODO: Replace this image name with another name once people decide what to replace it with.
_default_base_image='tensorflow/tensorflow:1.13.2-py3'
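To show how the new annotations are meant to be used, here is a minimal usage sketch (not part of this diff; the function and input names are made up). InputPath makes the system pass the path of the data file, while InputTextFile passes an already opened text stream:

# Hypothetical usage sketch, not part of the diff.
from kfp.components import InputPath, InputTextFile, func_to_container_op

def count_lines(text_path: InputPath(str)) -> int:
    # text_path is a plain string path to the file that holds the input data
    with open(text_path) as f:
        return sum(1 for _ in f)

def count_chars(text_file: InputTextFile(str)) -> int:
    # text_file is an already opened text stream; read it directly
    return len(text_file.read())

count_lines_op = func_to_container_op(count_lines)
count_chars_op = func_to_container_op(count_chars)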

@@ -181,7 +202,13 @@ def annotation_to_type_struct(annotation):
        return type_name

    for parameter in parameters:
        type_struct = annotation_to_type_struct(parameter.annotation)
        parameter_annotation = parameter.annotation
        passing_style = None
        if isinstance(parameter_annotation, (InputPath, InputTextFile, InputBinaryFile)):
            passing_style = type(parameter_annotation)
            parameter_annotation = parameter_annotation.type
        # TODO: Fix the input names: "number_file_path" parameter should be exposed as "number" input
        type_struct = annotation_to_type_struct(parameter_annotation)
        #TODO: Humanize the input/output names

        input_spec = InputSpec(
@@ -192,7 +219,7 @@ def annotation_to_type_struct(annotation):
            input_spec.optional = True
            if parameter.default is not None:
                input_spec.default = serialize_value(parameter.default, type_struct)

        input_spec._passing_style = passing_style
        inputs.append(input_spec)

    #Analyzing the return type annotations.
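The branch added above is the only signature-analysis change: a file-style annotation records its passing style and is unwrapped to the inner type before the InputSpec is built. A small standalone sketch of that step (the function below is hypothetical):

# Hypothetical sketch of the annotation-unwrapping step shown in this hunk.
import inspect
from kfp.components import InputPath, InputTextFile, InputBinaryFile

def consume(number_file_path: InputPath(int), threshold: float) -> int:
    ...

for parameter in inspect.signature(consume).parameters.values():
    annotation = parameter.annotation
    passing_style = None
    if isinstance(annotation, (InputPath, InputTextFile, InputBinaryFile)):
        passing_style = type(annotation)  # e.g. the InputPath class
        annotation = annotation.type      # the wrapped type, e.g. int
    print(parameter.name, passing_style, annotation)
# Prints roughly: number_file_path <class ...InputPath> <class 'int'>
#                 threshold None <class 'float'>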
@@ -275,6 +302,19 @@ def get_deserializer_and_register_definitions(type_name):
            return deserializer_code_str
        return 'str'

    pre_func_definitions = set()
    def get_argparse_type_for_input_file(passing_style):
        if passing_style is InputPath:
            pre_func_definitions.add(inspect.getsource(InputPath))
            return 'str'
        elif passing_style is InputTextFile:
            pre_func_definitions.add(inspect.getsource(InputTextFile))
            return "argparse.FileType('rt')"
        elif passing_style is InputBinaryFile:
            pre_func_definitions.add(inspect.getsource(InputBinaryFile))
            return "argparse.FileType('rb')"
        return None

    def get_serializer_and_register_definitions(type_name) -> str:
        if type_name in type_name_to_serializer:
            serializer_func = type_name_to_serializer[type_name]
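The mapping above is what makes the generated command-line parser file-aware: for InputPath the argument stays a plain string path, while for InputTextFile/InputBinaryFile the parser is given argparse.FileType, so argparse itself opens the file and hands the user function an already open stream. A standalone illustration (file and flag names are made up):

# Standalone illustration of the argparse behavior relied on above.
import argparse
import os
import tempfile

path = os.path.join(tempfile.mkdtemp(), 'number')
with open(path, 'w') as f:
    f.write('42')

_parser = argparse.ArgumentParser()
# InputPath style: the value stays a plain path string.
_parser.add_argument("--number-path", dest="number_path", type=str)
# InputTextFile style: argparse opens the file and passes a text stream.
_parser.add_argument("--number-file", dest="number_file", type=argparse.FileType('rt'))

args = _parser.parse_args(["--number-path", path, "--number-file", path])
print(args.number_path)         # a str path
print(args.number_file.read())  # '42', read from the already open stream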
@@ -300,19 +340,24 @@ def get_serializer_and_register_definitions(type_name) -> str:
        line = '_parser.add_argument("{param_flag}", dest="{param_var}", type={param_type}, required={is_required}, default=argparse.SUPPRESS)'.format(
            param_flag=param_flag,
            param_var=input.name,
            param_type=get_deserializer_and_register_definitions(input.type),
            param_type=get_argparse_type_for_input_file(input._passing_style) or get_deserializer_and_register_definitions(input.type),
            is_required=str(is_required),
        )
        arg_parse_code_lines.append(line)

        if input._passing_style in [InputPath, InputTextFile, InputBinaryFile]:
            arguments_for_input = [param_flag, InputPathPlaceholder(input.name)]
        else:
            arguments_for_input = [param_flag, InputValuePlaceholder(input.name)]

        if is_required:
            arguments.append(param_flag)
            arguments.append(InputValuePlaceholder(input.name))
            arguments.extend(arguments_for_input)
        else:
            arguments.append(
                IfPlaceholder(
                    IfPlaceholderStructure(
                        condition=IsPresentPlaceholder(input.name),
                        then_value=[param_flag, InputValuePlaceholder(input.name)],
                        then_value=arguments_for_input,
                    )
                )
            )
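The placeholder change above is what distinguishes the two input kinds in the resulting component spec: value inputs keep an InputValuePlaceholder in the container arguments, while file-style inputs get an InputPathPlaceholder, telling the backend to materialize the data as a file and substitute its path. A hedged way to see this (the attribute path component_spec.implementation.container.args is assumed from the kfp component structures):

# Sketch: compare how a value input and a file input appear in the container args.
from kfp.components import InputPath, func_to_container_op

def consume(number: int, number_file_path: InputPath(int)) -> int:
    with open(number_file_path) as f:
        return number + int(f.read())

op = func_to_container_op(consume)
for arg in op.component_spec.implementation.container.args:
    print(arg)
# Expected to show an InputValuePlaceholder for "number" and
# an InputPathPlaceholder for "number_file_path" among the flags.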
@@ -336,6 +381,8 @@ def get_serializer_and_register_definitions(type_name) -> str:
        serializer_call_str = get_serializer_and_register_definitions(output.type)
        output_serialization_expression_strings.append(serializer_call_str)

    pre_func_code = '\n'.join(list(pre_func_definitions))

    arg_parse_code_lines = list(definitions) + arg_parse_code_lines

    arg_parse_code_lines.extend([
@@ -345,6 +392,8 @@ def get_serializer_and_register_definitions(type_name) -> str:

    full_source = \
'''\
{pre_func_code}

{extra_code}

{func_code}
@@ -371,6 +420,7 @@ def get_serializer_and_register_definitions(type_name) -> str:
'''.format(
        func_name=func.__name__,
        func_code=func_code,
        pre_func_code=pre_func_code,
        extra_code=extra_code,
        arg_parse_code='\n'.join(arg_parse_code_lines),
        output_serialization_code=',\n '.join(output_serialization_expression_strings),
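With pre_func_code added to the template, the generated standalone program for a file-input component now starts with the source of the passing-style class (so the annotation resolves), followed by the user code and an argparse parser that uses the types selected above. A hand-written illustration of the rough shape, not literal generator output:

# Rough, hand-written shape of the generated program for
# consume_file_path(number_file_path: InputPath(int)); not literal generator output.

class InputPath:                    # pre_func_code: embedded source of the annotation class
    def __init__(self, type=None):
        self.type = type

def consume_file_path(number_file_path: InputPath(int)) -> int:  # func_code: the user function
    with open(number_file_path) as f:
        return int(f.read())

import argparse                                                   # arg_parse_code
_parser = argparse.ArgumentParser(prog='Consume file path')
_parser.add_argument("--number-file-path", dest="number_file_path", type=str, required=True, default=argparse.SUPPRESS)
_parsed_args = vars(_parser.parse_args())

_outputs = consume_file_path(**_parsed_args)                      # output serialization omitted here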
sdk/python/tests/components/test_python_op.py (70 changes: 69 additions & 1 deletion)
@@ -19,6 +19,7 @@
from pathlib import Path
from typing import Callable

import kfp
import kfp.components as comp

def add_two_numbers(a: float, b: float) -> float:
@@ -36,6 +37,21 @@ def components_local_output_dir_context(output_dir: str):
        comp._components._outputs_dir = old_dir


@contextmanager
def components_override_input_output_dirs_context(inputs_dir: str, outputs_dir: str):
    old_inputs_dir = comp._components._inputs_dir
    old_outputs_dir = comp._components._outputs_dir
    try:
        if inputs_dir:
            comp._components._inputs_dir = inputs_dir
        if outputs_dir:
            comp._components._outputs_dir = outputs_dir
        yield
    finally:
        comp._components._inputs_dir = old_inputs_dir
        comp._components._outputs_dir = old_outputs_dir


module_level_variable = 10


@@ -117,9 +133,16 @@ def helper_test_component_using_local_call(self, component_task_factory: Callable
        with tempfile.TemporaryDirectory() as temp_dir_name:
            # Creating task from the component.
            # We do it in a special context that allows us to control the output file locations.
            with components_local_output_dir_context(temp_dir_name):
            inputs_path = Path(temp_dir_name) / 'inputs'
            outputs_path = Path(temp_dir_name) / 'outputs'
            with components_override_input_output_dirs_context(str(inputs_path), str(outputs_path)):
                task = component_task_factory(**arguments)

            # Preparing input files
            for input_name, input_file_path in (task.input_artifact_paths or {}).items():
                Path(input_file_path).parent.mkdir(parents=True, exist_ok=True)
                Path(input_file_path).write_text(str(arguments[input_name]))

            # Constructing the full command-line from resolved command+args
            full_command = task.command + task.arguments

@@ -494,6 +517,51 @@ def produce_list() -> list:
        self.helper_test_component_using_local_call(task_factory, arguments={}, expected_output_values={'output': expected_output})


    def test_input_path(self):
        from kfp.components import InputPath
        def consume_file_path(number_file_path: InputPath(int)) -> int:
            with open(number_file_path) as f:
                string_data = f.read()
                return int(string_data)

        task_factory = comp.func_to_container_op(consume_file_path)

        self.assertEqual(task_factory.component_spec.inputs[0].type, 'Integer')

        # TODO: Fix the input names: "number_file_path" parameter should be exposed as "number" input
        self.helper_test_component_using_local_call(task_factory, arguments={'number_file_path': "42"}, expected_output_values={'output': '42'})


    def test_input_text_file(self):
        from kfp.components import InputTextFile
        def consume_file_path(number_file: InputTextFile(int)) -> int:
            string_data = number_file.read()
            assert isinstance(string_data, str)
            return int(string_data)

        task_factory = comp.func_to_container_op(consume_file_path)

        self.assertEqual(task_factory.component_spec.inputs[0].type, 'Integer')

        # TODO: Fix the input names: "number_file" parameter should be exposed as "number" input
        self.helper_test_component_using_local_call(task_factory, arguments={'number_file': "42"}, expected_output_values={'output': '42'})


    def test_input_binary_file(self):
        from kfp.components import InputBinaryFile
        def consume_file_path(number_file: InputBinaryFile(int)) -> int:
            bytes_data = number_file.read()
            assert isinstance(bytes_data, bytes)
            return int(bytes_data)

        task_factory = comp.func_to_container_op(consume_file_path)

        self.assertEqual(task_factory.component_spec.inputs[0].type, 'Integer')

        # TODO: Fix the input names: "number_file" parameter should be exposed as "number" input
        self.helper_test_component_using_local_call(task_factory, arguments={'number_file': "42"}, expected_output_values={'output': '42'})


    def test_end_to_end_python_component_pipeline_compilation(self):
        import kfp.components as comp

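For reference, a minimal sketch of what the tests above exercise when run locally: creating a task from a file-input component and inspecting how the argument is passed. The value given for the input is written to a file whose path is reported in task.input_artifact_paths, and the resolved command line refers to that path (the printed path below is illustrative):

# Minimal local sketch mirroring helper_test_component_using_local_call.
from kfp.components import InputPath, func_to_container_op

def consume_file_path(number_file_path: InputPath(int)) -> int:
    with open(number_file_path) as f:
        return int(f.read())

task_factory = func_to_container_op(consume_file_path)
task = task_factory(number_file_path='42')  # the value is materialized as a file for the container

print(task.command + task.arguments)  # resolved command line containing the input file path
print(task.input_artifact_paths)      # e.g. {'number_file_path': '<inputs dir>/number_file_path/data'}, path illustrative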