From bed93db280bc3c6bf74360f4788ee0493a2e98c8 Mon Sep 17 00:00:00 2001 From: esad Date: Mon, 8 May 2023 17:43:13 -0400 Subject: [PATCH] requested changes Signed-off-by: esad --- .../flytekitplugins/papermill/__init__.py | 2 +- .../flytekitplugins/papermill/task.py | 61 +++++++++++++------ plugins/flytekit-papermill/tests/test_task.py | 40 +++++++++++- .../tests/testdata/nb-simple.ipynb | 14 +---- 4 files changed, 82 insertions(+), 35 deletions(-) diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py index 1172235d750..c4d54df7213 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/__init__.py @@ -11,4 +11,4 @@ record_outputs """ -from .task import NotebookTask, read_input, record_outputs +from .task import NotebookTask, read_flytedirectory, read_flytefile, read_structureddataset, record_outputs diff --git a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py index 4731c195fb1..c6eb29b42e2 100644 --- a/plugins/flytekit-papermill/flytekitplugins/papermill/task.py +++ b/plugins/flytekit-papermill/flytekitplugins/papermill/task.py @@ -13,7 +13,7 @@ from google.protobuf import text_format as _text_format from nbconvert import HTMLExporter -from flytekit import FlyteContext, PythonInstanceTask +from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset from flytekit.configuration import SerializationSettings from flytekit.core import utils from flytekit.core.context_manager import ExecutionParameters @@ -22,7 +22,8 @@ from flytekit.loggers import logger from flytekit.models import task as task_models from flytekit.models.literals import Literal, LiteralMap -from flytekit.types.file import HTMLPage, PythonNotebook +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, HTMLPage, PythonNotebook T = typing.TypeVar("T") @@ -259,8 +260,14 @@ def execute(self, **kwargs) -> Any: """ logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.") # Execute Notebook via Papermill. - seralized_kwargs = serialize_inputs(**kwargs) - pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=seralized_kwargs, log_output=self._stream_logs) # type: ignore + + for k, v in kwargs.items(): + if isinstance(v, (FlyteFile, FlyteDirectory)): + kwargs[k] = save_literal_to_file(v) + elif isinstance(v, StructuredDataset): + kwargs[k] = save_literal_to_file(v) + + pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore outputs = self.extract_outputs(self.output_notebook_path) self.render_nb_html(self.output_notebook_path, self.rendered_output_path) @@ -313,28 +320,23 @@ def record_outputs(**kwargs) -> str: return LiteralMap(literals=m).to_flyte_idl() -def serialize_inputs(**kwargs) -> Dict[str, str]: +def save_literal_to_file(input: Any) -> str: """ - Serializes the inputs and saves separate files. Returns a dictionary + Serializes an input """ - if kwargs is None: - return {} - - outputs = {} ctx = FlyteContext.current_context() - for k, v in kwargs.items(): - expected = TypeEngine.to_literal_type(type(v)) - lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected) - - tmp = tempfile.mktemp(suffix="bin") - utils.write_proto_to_file(lit.to_flyte_idl(), tmp) - outputs[k] = tmp + expected = TypeEngine.to_literal_type(type(input)) + lit = TypeEngine.to_literal(ctx, python_type=type(input), python_val=input, expected=expected) - return outputs + tmp_file = tempfile.mktemp(suffix="bin") + utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file) + return tmp_file def read_input(path: str, dtype: T) -> T: - + """ + Reads a Flyte literal from a file + """ if type(path) == dtype: return path @@ -343,3 +345,24 @@ def read_input(path: str, dtype: T) -> T: ctx = FlyteContext.current_context() python_value = TypeEngine.to_python_value(ctx, lit, dtype) return python_value + + +def read_flytefile(path: str) -> T: + """ + Use this method to read a FlyteFile literal from a file. + """ + return read_input(path=path, dtype=FlyteFile) + + +def read_flytedirectory(path: str) -> T: + """ + Use this method to read a FlyteDirectory literal from a file. + """ + return read_input(path=path, dtype=FlyteDirectory) + + +def read_structureddataset(path: str) -> T: + """ + Use this method to read a StructuredDataset literal from a file. + """ + return read_input(path=path, dtype=StructuredDataset) diff --git a/plugins/flytekit-papermill/tests/test_task.py b/plugins/flytekit-papermill/tests/test_task.py index 1947d09445c..407d5faac69 100644 --- a/plugins/flytekit-papermill/tests/test_task.py +++ b/plugins/flytekit-papermill/tests/test_task.py @@ -1,14 +1,17 @@ import datetime import os +import tempfile +import pandas as pd from flytekitplugins.papermill import NotebookTask from flytekitplugins.pod import Pod from kubernetes.client import V1Container, V1PodSpec import flytekit -from flytekit import kwtypes +from flytekit import StructuredDataset, kwtypes, task, workflow from flytekit.configuration import Image, ImageConfig -from flytekit.types.file import PythonNotebook +from flytekit.types.directory import FlyteDirectory +from flytekit.types.file import FlyteFile, PythonNotebook from .testdata.datatype import X @@ -134,3 +137,36 @@ def test_notebook_pod_task(): nb.get_command(serialization_settings) == nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"] ) + + +def test_flyte_types(): + @task + def create_file() -> FlyteFile: + tmp_file = tempfile.mktemp() + with open(tmp_file, "w") as f: + f.write("abc") + return FlyteFile(path=tmp_file) + + @task + def create_dir() -> FlyteDirectory: + tmp_dir = tempfile.mkdtemp() + with open(os.path.join(tmp_dir, "file.txt"), "w") as f: + f.write("abc") + return FlyteDirectory(path=tmp_dir) + + @task + def create_sd() -> StructuredDataset: + df = pd.DataFrame({"a": [1, 2], "b": [3, 4]}) + return StructuredDataset(dataframe=df) + + ff = create_file() + fd = create_dir() + sd = create_sd() + + nb_name = "nb-types" + nb_types = NotebookTask( + name="test", + notebook_path=_get_nb_path(nb_name, abs=False), + inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset), + ) + nb_types.execute(ff=ff, fd=fd, sd=sd) diff --git a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb index 11130e16d2b..1ad7aaed4ac 100644 --- a/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb +++ b/plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb @@ -13,19 +13,6 @@ "pi = 3.14" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "tags": [] - }, - "outputs": [], - "source": [ - "from flytekitplugins.papermill import record_outputs, read_input\n", - "\n", - "pi = read_input(pi, float)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -46,6 +33,7 @@ }, "outputs": [], "source": [ + "from flytekitplugins.papermill import record_outputs\n", "record_outputs(square=out)" ] },