Skip to content

Commit

Permalink
[BUG] Convert protobuf to literal as remote exec (#2925)
Browse files Browse the repository at this point in the history
* fix: Convert protobuf to literal as remote exec

Signed-off-by: JiaWei Jiang <[email protected]>

* test: Test conversion of DC attrs under generic IDL

Signed-off-by: JiaWei Jiang <[email protected]>

* Add type determination for ListValue condition

Signed-off-by: JiaWei Jiang <[email protected]>

* Handle all the other conditions except for ListValue

Signed-off-by: JiaWei Jiang <[email protected]>

* Remove parquet file suffix

Signed-off-by: JiaWei Jiang <[email protected]>

* Skip the test if pd isn't installed

Signed-off-by: JiaWei Jiang <[email protected]>

* test: Add mini dc attr tests without pd

Signed-off-by: JiaWei Jiang <[email protected]>

* Lint to remove trailing whitespaces

Signed-off-by: JiaWei Jiang <[email protected]>

---------

Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 authored Nov 20, 2024
1 parent b04bc8d commit 2e40e76
Show file tree
Hide file tree
Showing 3 changed files with 326 additions and 5 deletions.
11 changes: 9 additions & 2 deletions flytekit/core/promise.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise:

curr_val = curr_val.value.literals[attr]
used += 1

# Scalar is always the leaf. There can't be a collection or map in a scalar.
if type(curr_val.value) is _literals_models.Scalar:
break
Expand All @@ -167,7 +168,6 @@ async def resolve_attr_path_in_promise(p: Promise) -> Promise:
This works correctly in remote execution.
Issue Link: https://github.com/flyteorg/flyte/issues/5959
"""

st = curr_val.value.value
new_st = resolve_attr_path_in_pb_struct(st, attr_path=p.attr_path[used:])
literal_type = TypeEngine.to_literal_type(type(new_st))
Expand Down Expand Up @@ -204,7 +204,14 @@ def resolve_attr_path_in_dict(d: dict, attr_path: List[Union[str, int]]) -> Any:
return curr_val


def resolve_attr_path_in_pb_struct(st: _struct.Struct, attr_path: List[Union[str, int]]) -> _struct.Struct:
def resolve_attr_path_in_pb_struct(
st: _struct.Struct, attr_path: List[Union[str, int]]
) -> Union[_struct.Struct, _struct.ListValue]:
"""
Resolves the protobuf struct (e.g. dataclass) with attribute path.
Note that the return type can be google.protobuf.struct_pb2.Struct or google.protobuf.struct_pb2.ListValue.
"""
curr_val = st
for attr in attr_path:
if attr not in curr_val:
Expand Down
28 changes: 25 additions & 3 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -972,12 +972,34 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
return LiteralType(simple=SimpleType.STRUCT, metadata={ProtobufTransformer.PB_FIELD_KEY: self.tag(t)})

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
struct = Struct()
"""
Convert the protobuf struct to literal.
This conversion supports two types of python_val:
1. google.protobuf.struct_pb2.Struct: A dictionary-like message
2. google.protobuf.struct_pb2.ListValue: An ordered collection of values
For details, please refer to the following issue:
https://github.com/flyteorg/flyte/issues/5959
Because the remote handling works without errors, we implement conversion with the logic as below:
https://github.com/flyteorg/flyte/blob/a87585ab7cbb6a047c76d994b3f127c4210070fd/flytepropeller/pkg/controller/nodes/attr_path_resolver.go#L72-L106
"""
try:
struct.update(_MessageToDict(cast(Message, python_val)))
if type(python_val) == _struct.ListValue:
literals = []
for v in python_val:
literal_type = TypeEngine.to_literal_type(type(v))
# Recursively convert python native values to literals
literal = TypeEngine.to_literal(ctx, v, type(v), literal_type)
literals.append(literal)
return Literal(collection=LiteralCollection(literals=literals))
else:
struct = Struct()
struct.update(_MessageToDict(cast(Message, python_val)))
return Literal(scalar=Scalar(generic=struct))
except Exception:
raise TypeTransformerFailedError("Failed to convert to generic protobuf struct")
return Literal(scalar=Scalar(generic=struct))

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
if not (lv and lv.scalar and lv.scalar.generic is not None):
Expand Down
292 changes: 292 additions & 0 deletions tests/flytekit/unit/core/test_generic_idl_protobuf_struct.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,292 @@
import sys
import tempfile
from pathlib import Path
from dataclasses import dataclass, field
from typing import Dict, List

import pytest
from flytekit import task, workflow
from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit.types.structured import StructuredDataset
from flytekit.types.schema import FlyteSchema
from flytekit import task, workflow
from enum import Enum


@pytest.fixture
def local_tmp_file():
with tempfile.NamedTemporaryFile(mode="w+", suffix=".txt") as tmp_file:
tmp_file.write("Hello World!")
tmp_file.flush()
tmp_file_path = tmp_file.name

yield tmp_file_path


@pytest.fixture
def local_tmp_dir():
with tempfile.TemporaryDirectory() as tmp_dir:
with open(Path(tmp_dir) / "example.txt", "w") as f:
f.write("Hello World!")

yield tmp_dir


@pytest.fixture
def local_pqt_file():
return Path(__file__).parents[2] / "integration/remote/workflows/basic/data/df.parquet"


@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.")
def test_all_dc_attrs(local_tmp_file, local_tmp_dir, local_pqt_file):
import os
import pandas as pd

# Enable generating protobuf struct in the generic IDL
# Please refer to https://github.com/flyteorg/flyte/issues/5959
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True"


class Status(Enum):
PENDING = "pending"
APPROVED = "approved"
REJECTED = "rejected"

@dataclass
class InnerDC:
a: int = -1
b: float = 2.1
c: str = "Hello, Flyte"
d: bool = False
e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_tmp_file)])
g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]])
h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})
j: Dict[int, FlyteFile] = field(default_factory=lambda: {
0: FlyteFile(local_tmp_file), 1: FlyteFile(local_tmp_file), -1: FlyteFile(local_tmp_file)
})
k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]})
l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}})
m: dict = field(default_factory=lambda: {"key": "value"})
n: FlyteFile = field(default_factory=lambda: FlyteFile(local_tmp_file))
o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_tmp_dir))
enum_status: Status = field(default=Status.PENDING)
sd: StructuredDataset = field(default_factory=lambda: StructuredDataset(uri=local_pqt_file, file_format="parquet"))
fsc: FlyteSchema = field(default_factory=lambda: FlyteSchema(local_path=local_pqt_file))

@dataclass
class DC:
a: int = -1
b: float = 2.1
c: str = "Hello, Flyte"
d: bool = False
e: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
f: List[FlyteFile] = field(default_factory=lambda: [FlyteFile(local_tmp_file)])
g: List[List[int]] = field(default_factory=lambda: [[0], [1], [-1]])
h: List[Dict[int, bool]] = field(default_factory=lambda: [{0: False}, {1: True}, {-1: True}])
i: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})
j: Dict[int, FlyteFile] = field(default_factory=lambda: {
0: FlyteFile(local_tmp_file), 1: FlyteFile(local_tmp_file), -1: FlyteFile(local_tmp_file)
})
k: Dict[int, List[int]] = field(default_factory=lambda: {0: [0, 1, -1]})
l: Dict[int, Dict[int, int]] = field(default_factory=lambda: {1: {-1: 0}})
m: dict = field(default_factory=lambda: {"key": "value"})
n: FlyteFile = field(default_factory=lambda: FlyteFile(local_tmp_file))
o: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(local_tmp_dir))
enum_status: Status = field(default=Status.PENDING)
sd: StructuredDataset = field(default_factory=lambda: StructuredDataset(uri=local_pqt_file, file_format="parquet"))
fsc: FlyteSchema = field(default_factory=lambda: FlyteSchema(local_path=local_pqt_file))

# Define a nested dataclass
inner_dc: InnerDC = field(default_factory=lambda: InnerDC())

@task
def t_dc(dc: DC) -> DC:
assert isinstance(dc, DC), "dc is not of type DC"

return dc

@task
def t_inner(inner_dc: InnerDC) -> InnerDC:
assert isinstance(inner_dc, InnerDC), "inner_dc is not of type InnerDC"

# f: List[FlyteFile]
for ff in inner_dc.f:
assert isinstance(ff, FlyteFile), "Expected FlyteFile"
with open(ff, "r") as f:
assert f.read() == EXPECTED_FILE_CONTENT, "File content mismatch in f"

# j: Dict[int, FlyteFile]
for _, ff in inner_dc.j.items():
assert isinstance(ff, FlyteFile), "Expected FlyteFile in j"
with open(ff, "r") as f:
assert f.read() == EXPECTED_FILE_CONTENT, "File content mismatch in j"

# n: FlyteFile
assert isinstance(inner_dc.n, FlyteFile), "n is not FlyteFile"
with open(inner_dc.n, "r") as f:
assert f.read() == EXPECTED_FILE_CONTENT, "File content mismatch in n"

# o: FlyteDirectory
assert isinstance(inner_dc.o, FlyteDirectory), "o is not FlyteDirectory"
assert not inner_dc.o.downloaded, "o should not be downloaded initially"
with open(os.path.join(inner_dc.o, "example.txt"), "r") as fh:
assert fh.read() == EXPECTED_FILE_CONTENT, "File content mismatch in o"
assert inner_dc.o.downloaded, "o should be marked as downloaded after access"

assert inner_dc.enum_status == Status.PENDING, "enum_status does not match"
assert isinstance(inner_dc.sd, StructuredDataset), "sd is not StructuredDataset"
assert isinstance(inner_dc.fsc, FlyteSchema), "fsc is not FlyteSchema"
print("All checks in InnerDC passed")

return inner_dc

@task
def t_test_all_attributes(
a: int, b: float, c: str, d: bool,
e: List[int], f: List[FlyteFile], g: List[List[int]], h: List[Dict[int, bool]],
i: Dict[int, bool], j: Dict[int, FlyteFile], k: Dict[int, List[int]], l: Dict[int, Dict[int, int]],
m: dict, n: FlyteFile, o: FlyteDirectory, enum_status: Status,
sd: StructuredDataset, fsc: FlyteSchema
) -> None:
# Strict type checks for simple types
assert isinstance(a, int), f"a is not int, it's {type(a)}"
assert a == -1
assert isinstance(b, float), f"b is not float, it's {type(b)}"
assert isinstance(c, str), f"c is not str, it's {type(c)}"
assert isinstance(d, bool), f"d is not bool, it's {type(d)}"

# Strict type checks for List[int]
assert isinstance(e, list) and all(isinstance(i, int) for i in e), "e is not List[int]"

# Strict type checks for List[FlyteFile]
assert isinstance(f, list) and all(isinstance(i, FlyteFile) for i in f), "f is not List[FlyteFile]"

# Strict type checks for List[List[int]]
assert isinstance(g, list) and all(
isinstance(i, list) and all(isinstance(j, int) for j in i) for i in g), "g is not List[List[int]]"

# Strict type checks for List[Dict[int, bool]]
assert isinstance(h, list) and all(
isinstance(i, dict) and all(isinstance(k, int) and isinstance(v, bool) for k, v in i.items()) for i in h
), "h is not List[Dict[int, bool]]"

# Strict type checks for Dict[int, bool]
assert isinstance(i, dict) and all(
isinstance(k, int) and isinstance(v, bool) for k, v in i.items()), "i is not Dict[int, bool]"

# Strict type checks for Dict[int, FlyteFile]
assert isinstance(j, dict) and all(
isinstance(k, int) and isinstance(v, FlyteFile) for k, v in j.items()), "j is not Dict[int, FlyteFile]"

# Strict type checks for Dict[int, List[int]]
assert isinstance(k, dict) and all(
isinstance(k, int) and isinstance(v, list) and all(isinstance(i, int) for i in v) for k, v in
k.items()), "k is not Dict[int, List[int]]"

# Strict type checks for Dict[int, Dict[int, int]]
assert isinstance(l, dict) and all(
isinstance(k, int) and isinstance(v, dict) and all(
isinstance(sub_k, int) and isinstance(sub_v, int) for sub_k, sub_v in v.items())
for k, v in l.items()), "l is not Dict[int, Dict[int, int]]"

# Strict type check for a generic dict
assert isinstance(m, dict), "m is not dict"

# Strict type check for FlyteFile
assert isinstance(n, FlyteFile), "n is not FlyteFile"

# Strict type check for FlyteDirectory
assert isinstance(o, FlyteDirectory), "o is not FlyteDirectory"

# # Strict type check for Enum
assert isinstance(enum_status, Status), "enum_status is not Status"

assert isinstance(sd, StructuredDataset), "sd is not StructuredDataset"
print("sd:", sd.open(pd.DataFrame).all())

assert isinstance(fsc, FlyteSchema), "fsc is not FlyteSchema"
print("fsc: ", fsc.open().all())

print("All attributes passed strict type checks.")


@workflow
def wf(dc: DC) -> None:
new_dc = t_dc(dc=dc)
t_inner(new_dc.inner_dc)

# Test outer dc
t_test_all_attributes(
a=new_dc.a, b=new_dc.b, c=new_dc.c, d=new_dc.d,
e=new_dc.e, f=new_dc.f, g=new_dc.g, h=new_dc.h,
i=new_dc.i, j=new_dc.j, k=new_dc.k, l=new_dc.l,
m=new_dc.m, n=new_dc.n, o=new_dc.o, enum_status=new_dc.enum_status,
sd=new_dc.sd, fsc=new_dc.fsc
)

# Test inner dc
t_test_all_attributes(
a=new_dc.inner_dc.a, b=new_dc.inner_dc.b, c=new_dc.inner_dc.c, d=new_dc.inner_dc.d,
e=new_dc.inner_dc.e, f=new_dc.inner_dc.f, g=new_dc.inner_dc.g, h=new_dc.inner_dc.h,
i=new_dc.inner_dc.i, j=new_dc.inner_dc.j, k=new_dc.inner_dc.k, l=new_dc.inner_dc.l,
m=new_dc.inner_dc.m, n=new_dc.inner_dc.n, o=new_dc.inner_dc.o, enum_status=new_dc.inner_dc.enum_status,
sd=new_dc.inner_dc.sd, fsc=new_dc.inner_dc.fsc
)


EXPECTED_FILE_CONTENT = "Hello World!"

wf(dc=DC())


def test_mini_dc_attrs():
"""
Test dc attributes which focuses only on protobuf structs.
This test doesn't depend on pandas.
"""
import os

# Enable generating protobuf struct in the generic IDL
# Please refer to https://github.com/flyteorg/flyte/issues/5959
os.environ["FLYTE_USE_OLD_DC_FORMAT"] = "True"

@dataclass
class InnerDC:
a: int = -1
b: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
c: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})

@dataclass
class DC:
a: int = -1
b: List[int] = field(default_factory=lambda: [0, 1, 2, -1, -2])
c: Dict[int, bool] = field(default_factory=lambda: {0: False, 1: True, -1: False})

# Define a nested dataclass
inner_dc: InnerDC = field(default_factory=lambda: InnerDC())

@task
def t_test_attrs(a: int, b: List[int], c: Dict[int, bool]) -> None:
assert isinstance(a, int), f"a is not int, it's {type(a)}"
assert a == -1

assert isinstance(b, list) and all(isinstance(i, int) for i in b), "b is not List[int]"

assert isinstance(c, dict) and all(
isinstance(k, int) and isinstance(v, bool) for k, v in c.items()
), "c is not Dict[int, bool]"

@workflow
def wf(dc: DC) -> None:
# Test outer dc
t_test_attrs(dc.a, dc.b, dc.c)

# Test inner dc
t_test_attrs(dc.inner_dc.a, dc.inner_dc.b, dc.inner_dc.c)


wf(dc=DC())

0 comments on commit 2e40e76

Please sign in to comment.