Skip to content

Commit

Permalink
Add image transformer (flyteorg#1901)
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin Su <[email protected]>
  • Loading branch information
pingsutw authored and ringohoffman committed Nov 24, 2023
1 parent 17dcf9e commit da32137
Show file tree
Hide file tree
Showing 9 changed files with 112 additions and 5 deletions.
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ torch<=1.12.1; python_version<'3.11'
# pytorch 2 supports python 3.11
torch<=2.0.0; python_version>='3.11' or platform_system!='Windows'

pillow
scikit-learn
types-protobuf
types-croniter
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,8 @@ def lazy_import_transformers(cls):
register_bigquery_handlers()
if is_imported("numpy"):
from flytekit.types import numpy # noqa: F401
if is_imported("PIL"):
from flytekit.types.file import image # noqa: F401

@classmethod
def to_literal_type(cls, python_type: Type) -> LiteralType:
Expand Down
2 changes: 1 addition & 1 deletion flytekit/experimental/eager_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ async def wrapper(*args, **kws):
return task(
wrapper,
secret_requests=secret_requests,
disable_deck=False,
enable_deck=True,
execution_mode=PythonFunctionTask.ExecutionBehavior.EAGER,
**kwargs,
)
Expand Down
82 changes: 82 additions & 0 deletions flytekit/types/file/image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pathlib
import typing
from typing import Type

import PIL.Image

from flytekit.core.context_manager import FlyteContext
from flytekit.core.type_engine import TypeEngine, TypeTransformer, TypeTransformerFailedError
from flytekit.models.core import types as _core_types
from flytekit.models.literals import Blob, BlobMetadata, Literal, Scalar
from flytekit.models.types import LiteralType

T = typing.TypeVar("T")


class PILImageTransformer(TypeTransformer[T]):
"""
TypeTransformer that supports PIL.Image as a native type.
"""

FILE_FORMAT = "PIL.Image"

def __init__(self):
super().__init__(name="PIL.Image", t=PIL.Image.Image)

def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(
blob=_core_types.BlobType(
format=self.FILE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)

def to_literal(
self, ctx: FlyteContext, python_val: PIL.Image.Image, python_type: Type[T], expected: LiteralType
) -> Literal:

meta = BlobMetadata(
type=_core_types.BlobType(
format=self.FILE_FORMAT, dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE
)
)

local_path = ctx.file_access.get_random_local_path() + ".png"
pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
python_val.save(local_path)

remote_path = ctx.file_access.get_random_remote_path(local_path)
ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
return Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path)))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> PIL.Image.Image:
try:
uri = lv.scalar.blob.uri
except AttributeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

local_path = ctx.file_access.get_random_local_path()
ctx.file_access.get_data(uri, local_path, is_multipart=False)

return PIL.Image.open(local_path)

def guess_python_type(self, literal_type: LiteralType) -> Type[T]:
if (
literal_type.blob is not None
and literal_type.blob.dimensionality == _core_types.BlobType.BlobDimensionality.SINGLE
and literal_type.blob.format == self.FILE_FORMAT
):
return PIL.Image.Image

raise ValueError(f"Transformer {self} cannot reverse {literal_type}")

def to_html(self, ctx: FlyteContext, python_val: PIL.Image.Image, expected_python_type: Type[T]) -> str:
import base64
from io import BytesIO

buffered = BytesIO()
python_val.save(buffered, format="PNG")
img_base64 = base64.b64encode(buffered.getvalue()).decode()
return f'<img src="data:image/png;base64,{img_base64}" alt="Rendered Image" />'


TypeEngine.register(PILImageTransformer())
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
if TYPE_CHECKING:
import markdown
import pandas as pd
import PIL
import PIL.Image
import plotly.express as px
else:
pd = lazy_module("pandas")
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def test_deck(start_method: str) -> None:

@task(
task_config=Elastic(nnodes=1, nproc_per_node=world_size, start_method=start_method),
disable_deck=False,
enable_deck=True,
)
def train():
import os
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-mlflow/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ from flytekit import task, workflow
from flytekitplugins.mlflow import mlflow_autolog
import mlflow

@task(disable_deck=False)
@task(enable_deck=True)
@mlflow_autolog(framework=mlflow.keras)
def train_model():
...
Expand Down
2 changes: 1 addition & 1 deletion plugins/flytekit-mlflow/tests/test_mlflow_tracking.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from flytekit import task


@task(disable_deck=False)
@task(enable_deck=True)
@mlflow_autolog(framework=mlflow.keras)
def train_model(epochs: int):
fashion_mnist = tf.keras.datasets.fashion_mnist
Expand Down
22 changes: 22 additions & 0 deletions tests/flytekit/unit/types/file/test_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import PIL.Image

from flytekit import task, workflow


@task(enable_deck=True)
def t1() -> PIL.Image.Image:
return PIL.Image.new("L", (100, 100), "black")


@task
def t2(im: PIL.Image.Image) -> PIL.Image.Image:
return im


@workflow
def wf():
t2(im=t1())


def test_image_transformer():
wf()

0 comments on commit da32137

Please sign in to comment.