diff --git a/dev-requirements.in b/dev-requirements.in index 3cb16d8d3b5..8c968cd54bf 100644 --- a/dev-requirements.in +++ b/dev-requirements.in @@ -26,6 +26,7 @@ torch<=1.12.1; python_version<'3.11' # pytorch 2 supports python 3.11 torch<=2.0.0; python_version>='3.11' or platform_system!='Windows' +pillow scikit-learn types-protobuf types-croniter diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index 7b4f17eda55..c30a919f2f0 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -874,6 +874,8 @@ def lazy_import_transformers(cls): register_bigquery_handlers() if is_imported("numpy"): from flytekit.types import numpy # noqa: F401 + if is_imported("PIL"): + from flytekit.types.file import image # noqa: F401 @classmethod def to_literal_type(cls, python_type: Type) -> LiteralType: diff --git a/flytekit/experimental/eager_function.py b/flytekit/experimental/eager_function.py index 264d0d641ab..e0f252e3122 100644 --- a/flytekit/experimental/eager_function.py +++ b/flytekit/experimental/eager_function.py @@ -531,7 +531,7 @@ async def wrapper(*args, **kws): return task( wrapper, secret_requests=secret_requests, - disable_deck=False, + enable_deck=True, execution_mode=PythonFunctionTask.ExecutionBehavior.EAGER, **kwargs, ) diff --git a/flytekit/types/file/image.py b/flytekit/types/file/image.py new file mode 100644 index 00000000000..d26389cf993 --- /dev/null +++ b/flytekit/types/file/image.py @@ -0,0 +1,82 @@ +import pathlib +import typing +from typing import Type + +import PIL.Image + +from flytekit.core.context_manager import FlyteContext +from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError +from flytekit.models.core import types as _core_types +from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar +from flytekit.models.types import LiteralType + +T = typing.TypeVar("T") + + +class PILImageTransformer(TypeTransformer[T]): + """ + TypeTransformer that supports PIL.Image as a native type. + """ + + FILE_FORMAT = "PIL.Image" + + def __init__(self): + super().__init__(name="PIL.Image", t=PIL.Image.Image) + + def get_literal_type(self, t: Type[T]) -> LiteralType: + return LiteralType( + blob=_core_types.BlobType( + format=self.FILE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + def to_literal( + self, ctx: FlyteContext, python_val: PIL.Image.Image, python_type: Type[T], expected: LiteralType + ) -> Literal: + + meta = BlobMetadata( + type=_core_types.BlobType( + format=self.FILE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE + ) + ) + + local_path = ctx.file_access.get_random_local_path() + ".png" + pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True) + python_val.save(local_path) + + remote_path = ctx.file_access.get_random_remote_path(local_path) + ctx.file_access.put_data(local_path, remote_path, is_multipart=False) + return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))) + + def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> PIL.Image.Image: + try: + uri = lv.scalar.blob.uri + except AttributeError: + raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}") + + local_path = ctx.file_access.get_random_local_path() + ctx.file_access.get_data(uri, local_path, is_multipart=False) + + return PIL.Image.open(local_path) + + def guess_python_type(self, literal_type: LiteralType) -> Type[T]: + if ( + literal_type.blob is not None + and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE + and literal_type.blob.format == self.FILE_FORMAT + ): + return PIL.Image.Image + + raise ValueError(f"Transformer {self} cannot reverse {literal_type}") + + def to_html(self, ctx: FlyteContext, python_val: PIL.Image.Image, expected_python_type: Type[T]) -> str: + import base64 + from io import BytesIO + + buffered = BytesIO() + python_val.save(buffered, format="PNG") + img_base64 = base64.b64encode(buffered.getvalue()).decode() + return f'Rendered Image' + + +TypeEngine.register(PILImageTransformer()) diff --git a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py index c090ea6a46a..b71ee832bf4 100644 --- a/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py +++ b/plugins/flytekit-deck-standard/flytekitplugins/deck/renderer.py @@ -6,7 +6,7 @@ if TYPE_CHECKING: import markdown import pandas as pd - import PIL + import PIL.Image import plotly.express as px else: pd = lazy_module("pandas") diff --git a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py index bced35a6df6..8b053428466 100644 --- a/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py +++ b/plugins/flytekit-kf-pytorch/tests/test_elastic_task.py @@ -120,7 +120,7 @@ def test_deck(start_method: str) -> None: @task( task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method), - disable_deck=False, + enable_deck=True, ) def train(): import os diff --git a/plugins/flytekit-mlflow/README.md b/plugins/flytekit-mlflow/README.md index 6cbee9cf59b..6a9a794a9f3 100644 --- a/plugins/flytekit-mlflow/README.md +++ b/plugins/flytekit-mlflow/README.md @@ -15,7 +15,7 @@ from flytekit import task, workflow from flytekitplugins.mlflow import mlflow_autolog import mlflow -@task(disable_deck=False) +@task(enable_deck=True) @mlflow_autolog(framework=mlflow.keras) def train_model(): ... diff --git a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py index 613cbfcd769..3605c7ee2f3 100644 --- a/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py +++ b/plugins/flytekit-mlflow/tests/test_mlflow_tracking.py @@ -6,7 +6,7 @@ from flytekit import task -@task(disable_deck=False) +@task(enable_deck=True) @mlflow_autolog(framework=mlflow.keras) def train_model(epochs: int): fashion_mnist = tf.keras.datasets.fashion_mnist diff --git a/tests/flytekit/unit/types/file/test_image.py b/tests/flytekit/unit/types/file/test_image.py new file mode 100644 index 00000000000..8f7469a4575 --- /dev/null +++ b/tests/flytekit/unit/types/file/test_image.py @@ -0,0 +1,22 @@ +import PIL.Image + +from flytekit import task, workflow + + +@task(enable_deck=True) +def t1() -> PIL.Image.Image: + return PIL.Image.new("L", (100, 100), "black") + + +@task +def t2(im: PIL.Image.Image) -> PIL.Image.Image: + return im + + +@workflow +def wf(): + t2(im=t1()) + + +def test_image_transformer(): + wf()