Skip to content

Commit

Permalink
StructuredDatasetTransformerEngine should derive default protocol fro…
Browse files Browse the repository at this point in the history
…m raw output prefix (#1107)


Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor committed Aug 2, 2022
1 parent c984814 commit 7fcda56
Show file tree
Hide file tree
Showing 8 changed files with 88 additions and 26 deletions.
4 changes: 4 additions & 0 deletions flytekit/core/data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,10 @@ def __init__(
self._raw_output_prefix = raw_output_prefix
self._data_config = data_config if data_config else DataConfig.auto()

@property
def raw_output_prefix(self) -> str:
return self._raw_output_prefix

@property
def data_config(self) -> DataConfig:
return self._data_config
Expand Down
17 changes: 6 additions & 11 deletions flytekit/types/structured/basic_dfs.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,9 @@ def decode(
return pq.read_table(local_dir)


for protocol in [LOCAL, S3]:
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol))
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol))
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol))
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol))

# Don't override the default for GCS.
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(GCS), default_for_type=False)
# Don't override default protocol
for protocol in [LOCAL, S3, GCS]:
StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=False)
27 changes: 21 additions & 6 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,10 +471,7 @@ def to_literal(
# 3. This is the third and probably most common case. The python StructuredDataset object wraps a dataframe
# that we will need to invoke an encoder for. Figure out which encoder to call and invoke it.
df_type = type(python_val.dataframe)
if python_val.uri is None:
protocol = self.DEFAULT_PROTOCOLS[df_type]
else:
protocol = protocol_prefix(python_val.uri)
protocol = self._protocol_from_type_or_prefix(ctx, df_type, python_val.uri)
return self.encode(
ctx,
python_val,
Expand All @@ -485,13 +482,31 @@ def to_literal(
)

# Otherwise assume it's a dataframe instance. Wrap it with some defaults
fmt = self.DEFAULT_FORMATS[python_type]
protocol = self.DEFAULT_PROTOCOLS[python_type]
if python_type in self.DEFAULT_FORMATS:
fmt = self.DEFAULT_FORMATS[python_type]
else:
logger.debug(f"No default format for type {python_type}, using system default.")
fmt = StructuredDataset.DEFAULT_FILE_FORMAT
protocol = self._protocol_from_type_or_prefix(ctx, python_type)
meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None)

sd = StructuredDataset(dataframe=python_val, metadata=meta)
return self.encode(ctx, sd, python_type, protocol, fmt, sdt)

def _protocol_from_type_or_prefix(self, ctx: FlyteContext, df_type: Type, uri: Optional[str] = None) -> str:
"""
Get the protocol from the default, if missing, then look it up from the uri if provided, if not then look
up from the provided context's file access.
"""
if df_type in self.DEFAULT_PROTOCOLS:
return self.DEFAULT_PROTOCOLS[df_type]
else:
protocol = protocol_prefix(uri or ctx.file_access.raw_output_prefix)
logger.debug(
f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.file_access.raw_output_prefix}"
)
return protocol

def encode(
self,
ctx: FlyteContext,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,10 @@ def decode(
return pl.read_parquet(path)


for protocol in [LOCAL, S3]:
for protocol in [LOCAL, S3, GCS]:
StructuredDatasetTransformerEngine.register(
PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=True
PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=False
)
StructuredDatasetTransformerEngine.register(
ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=True
ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=False
)
StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler(GCS), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler(GCS), default_for_type=False)
Original file line number Diff line number Diff line change
Expand Up @@ -49,5 +49,5 @@ def decode(


for protocol in ["/", "s3"]:
StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=True)
StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=False)
StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=False)
1 change: 1 addition & 0 deletions tests/flytekit/unit/core/test_data_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ def test_get_random_remote_path():
fp = FileAccessProvider("/tmp", "s3://my-bucket")
path = fp.get_random_remote_path()
assert path.startswith("s3://my-bucket")
assert fp.raw_output_prefix == "s3://my-bucket"


def test_is_remote():
Expand Down
29 changes: 29 additions & 0 deletions tests/flytekit/unit/core/test_dataclass.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from dataclasses import dataclass
from typing import List

from dataclasses_json import dataclass_json

from flytekit.core.task import task
from flytekit.core.workflow import workflow


def test_dataclass():
@dataclass_json
@dataclass
class AppParams(object):
snapshotDate: str
region: str
preprocess: bool
listKeys: List[str]

@task
def t1() -> AppParams:
ap = AppParams(snapshotDate="4/5/2063", region="us-west-3", preprocess=False, listKeys=["a", "b"])
return ap

@workflow
def wf() -> AppParams:
return t1()

res = wf()
assert res.region == "us-west-3"
24 changes: 22 additions & 2 deletions tests/flytekit/unit/core/test_structured_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import tempfile
import typing

import pytest

import flytekit.configuration
from flytekit.configuration import Image, ImageConfig
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.type_engine import TypeEngine
from flytekit.models import literals
from flytekit.models.literals import StructuredDatasetMetadata
Expand All @@ -20,8 +22,8 @@

from flytekit import kwtypes, task
from flytekit.types.structured.structured_dataset import (
LOCAL,
PARQUET,
S3,
StructuredDataset,
StructuredDatasetDecoder,
StructuredDatasetEncoder,
Expand Down Expand Up @@ -336,7 +338,7 @@ def test_to_python_value_without_incoming_columns():
def test_format_correct():
class TempEncoder(StructuredDatasetEncoder):
def __init__(self):
super().__init__(pd.DataFrame, S3, "avro")
super().__init__(pd.DataFrame, LOCAL, "avro")

def encode(
self,
Expand Down Expand Up @@ -375,3 +377,21 @@ def t1() -> Annotated[StructuredDataset, "avro"]:
return StructuredDataset(dataframe=df)

assert t1().file_format == "avro"


def test_protocol_detection():
# We've don't register defaults to the transformer engine
assert pd.DataFrame not in StructuredDatasetTransformerEngine.DEFAULT_PROTOCOLS
e = StructuredDatasetTransformerEngine()
ctx = FlyteContextManager.current_context()
protocol = e._protocol_from_type_or_prefix(ctx, pd.DataFrame)
assert protocol == "/"

with tempfile.TemporaryDirectory() as tmp_dir:
fs = FileAccessProvider(local_sandbox_dir=tmp_dir, raw_output_prefix="s3://fdsa")
ctx2 = ctx.with_file_access(fs).build()
protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame)
assert protocol == "s3"

protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo")
assert protocol == "bq"

0 comments on commit 7fcda56

Please sign in to comment.