Skip to content

Commit

Permalink
requested changes
Browse files Browse the repository at this point in the history
Signed-off-by: esad <[email protected]>
  • Loading branch information
peridotml committed May 8, 2023
1 parent 1c0717d commit 7b0b142
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@
record_outputs
"""

from .task import NotebookTask, read_input, record_outputs
from .task import NotebookTask, read_flytedirectory, read_flytefile, read_structureddataset, record_outputs
61 changes: 42 additions & 19 deletions plugins/flytekit-papermill/flytekitplugins/papermill/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from google.protobuf import text_format as _text_format
from nbconvert import HTMLExporter

from flytekit import FlyteContext, PythonInstanceTask
from flytekit import FlyteContext, PythonInstanceTask, StructuredDataset
from flytekit.configuration import SerializationSettings
from flytekit.core import utils
from flytekit.core.context_manager import ExecutionParameters
Expand All @@ -22,7 +22,8 @@
from flytekit.loggers import logger
from flytekit.models import task as task_models
from flytekit.models.literals import Literal, LiteralMap
from flytekit.types.file import HTMLPage, PythonNotebook
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, HTMLPage, PythonNotebook

T = typing.TypeVar("T")

Expand Down Expand Up @@ -259,8 +260,14 @@ def execute(self, **kwargs) -> Any:
"""
logger.info(f"Hijacking the call for task-type {self.task_type}, to call notebook.")
# Execute Notebook via Papermill.
seralized_kwargs = serialize_inputs(**kwargs)
pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=seralized_kwargs, log_output=self._stream_logs) # type: ignore

for k, v in kwargs.items():
if isinstance(v, (FlyteFile, FlyteDirectory)):
kwargs[k] = save_literal_to_file(v)
elif isinstance(v, StructuredDataset):
kwargs[k] = save_literal_to_file(v)

pm.execute_notebook(self._notebook_path, self.output_notebook_path, parameters=kwargs, log_output=self._stream_logs) # type: ignore

outputs = self.extract_outputs(self.output_notebook_path)
self.render_nb_html(self.output_notebook_path, self.rendered_output_path)
Expand Down Expand Up @@ -313,28 +320,23 @@ def record_outputs(**kwargs) -> str:
return LiteralMap(literals=m).to_flyte_idl()


def serialize_inputs(**kwargs) -> Dict[str, str]:
def save_literal_to_file(input: Any) -> str:
"""
Serializes the inputs and saves separate files. Returns a dictionary
Serializes an input
"""
if kwargs is None:
return {}

outputs = {}
ctx = FlyteContext.current_context()
for k, v in kwargs.items():
expected = TypeEngine.to_literal_type(type(v))
lit = TypeEngine.to_literal(ctx, python_type=type(v), python_val=v, expected=expected)

tmp = tempfile.mktemp(suffix="bin")
utils.write_proto_to_file(lit.to_flyte_idl(), tmp)
outputs[k] = tmp
expected = TypeEngine.to_literal_type(type(input))
lit = TypeEngine.to_literal(ctx, python_type=type(input), python_val=input, expected=expected)

return outputs
tmp_file = tempfile.mktemp(suffix="bin")
utils.write_proto_to_file(lit.to_flyte_idl(), tmp_file)
return tmp_file


def read_input(path: str, dtype: T) -> T:

"""
Reads a Flyte literal from a file
"""
if type(path) == dtype:
return path

Expand All @@ -343,3 +345,24 @@ def read_input(path: str, dtype: T) -> T:
ctx = FlyteContext.current_context()
python_value = TypeEngine.to_python_value(ctx, lit, dtype)
return python_value


def read_flytefile(path: str) -> T:
"""
Use this method to read a FlyteFile literal from a file.
"""
return read_input(path=path, dtype=FlyteFile)


def read_flytedirectory(path: str) -> T:
"""
Use this method to read a FlyteDirectory literal from a file.
"""
return read_input(path=path, dtype=FlyteDirectory)


def read_structureddataset(path: str) -> T:
"""
Use this method to read a StructuredDataset literal from a file.
"""
return read_input(path=path, dtype=StructuredDataset)
40 changes: 38 additions & 2 deletions plugins/flytekit-papermill/tests/test_task.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import datetime
import os
import tempfile

import pandas as pd
from flytekitplugins.papermill import NotebookTask
from flytekitplugins.pod import Pod
from kubernetes.client import V1Container, V1PodSpec

import flytekit
from flytekit import kwtypes
from flytekit import StructuredDataset, kwtypes, task, workflow
from flytekit.configuration import Image, ImageConfig
from flytekit.types.file import PythonNotebook
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile, PythonNotebook

from .testdata.datatype import X

Expand Down Expand Up @@ -134,3 +137,36 @@ def test_notebook_pod_task():
nb.get_command(serialization_settings)
== nb.get_k8s_pod(serialization_settings).pod_spec["containers"][0]["args"]
)


def test_flyte_types():
@task
def create_file() -> FlyteFile:
tmp_file = tempfile.mktemp()
with open(tmp_file, "w") as f:
f.write("abc")
return FlyteFile(path=tmp_file)

@task
def create_dir() -> FlyteDirectory:
tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "file.txt"), "w") as f:
f.write("abc")
return FlyteDirectory(path=tmp_dir)

@task
def create_sd() -> StructuredDataset:
df = pd.DataFrame({"a": [1, 2], "b": [3, 4]})
return StructuredDataset(dataframe=df)

ff = create_file()
fd = create_dir()
sd = create_sd()

nb_name = "nb-types"
nb_types = NotebookTask(
name="test",
notebook_path=_get_nb_path(nb_name, abs=False),
inputs=kwtypes(ff=FlyteFile, fd=FlyteDirectory, sd=StructuredDataset),
)
nb_types.execute(ff=ff, fd=fd, sd=sd)
14 changes: 1 addition & 13 deletions plugins/flytekit-papermill/tests/testdata/nb-simple.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,6 @@
"pi = 3.14"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from flytekitplugins.papermill import record_outputs, read_input\n",
"\n",
"pi = read_input(pi, float)"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand All @@ -46,6 +33,7 @@
},
"outputs": [],
"source": [
"from flytekitplugins.papermill import record_outputs\n",
"record_outputs(square=out)"
]
},
Expand Down

0 comments on commit 7b0b142

Please sign in to comment.