From 3be4e6b155e129c5f5b6e64005222d75b1c1390a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Thu, 21 Jul 2022 11:45:20 -0700 Subject: [PATCH] StructuredDatasetTransformerEngine should derive default protocol from raw output prefix (#1107) Signed-off-by: Yee Hing Tong --- flytekit/core/data_persistence.py | 4 +++ flytekit/types/structured/basic_dfs.py | 17 ++++------- .../types/structured/structured_dataset.py | 27 +++++++++++++---- .../flytekitplugins/polars/sd_transformers.py | 8 ++--- .../flytekitplugins/spark/sd_transformers.py | 4 +-- .../unit/core/test_data_persistence.py | 1 + tests/flytekit/unit/core/test_dataclass.py | 29 +++++++++++++++++++ .../unit/core/test_structured_dataset.py | 24 +++++++++++++-- 8 files changed, 88 insertions(+), 26 deletions(-) create mode 100644 tests/flytekit/unit/core/test_dataclass.py diff --git a/flytekit/core/data_persistence.py b/flytekit/core/data_persistence.py index 9cde937eb8..e69e3f6476 100644 --- a/flytekit/core/data_persistence.py +++ b/flytekit/core/data_persistence.py @@ -318,6 +318,10 @@ def __init__( self._raw_output_prefix = raw_output_prefix self._data_config = data_config if data_config else DataConfig.auto() + @property + def raw_output_prefix(self) -> str: + return self._raw_output_prefix + @property def data_config(self) -> DataConfig: return self._data_config diff --git a/flytekit/types/structured/basic_dfs.py b/flytekit/types/structured/basic_dfs.py index 39557d33ad..2b91e3422e 100644 --- a/flytekit/types/structured/basic_dfs.py +++ b/flytekit/types/structured/basic_dfs.py @@ -105,14 +105,9 @@ def decode( return pq.read_table(local_dir) -for protocol in [LOCAL, S3]: - StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol)) - StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol)) - StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol)) - 
StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol)) - -# Don't override the default for GCS. -StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(GCS), default_for_type=False) -StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(GCS), default_for_type=False) -StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(GCS), default_for_type=False) -StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(GCS), default_for_type=False) +# Don't override default protocol +for protocol in [LOCAL, S3, GCS]: + StructuredDatasetTransformerEngine.register(PandasToParquetEncodingHandler(protocol), default_for_type=False) + StructuredDatasetTransformerEngine.register(ParquetToPandasDecodingHandler(protocol), default_for_type=False) + StructuredDatasetTransformerEngine.register(ArrowToParquetEncodingHandler(protocol), default_for_type=False) + StructuredDatasetTransformerEngine.register(ParquetToArrowDecodingHandler(protocol), default_for_type=False) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index af60599f0c..d37b9aff37 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -471,10 +471,7 @@ def to_literal( # 3. This is the third and probably most common case. The python StructuredDataset object wraps a dataframe # that we will need to invoke an encoder for. Figure out which encoder to call and invoke it. df_type = type(python_val.dataframe) - if python_val.uri is None: - protocol = self.DEFAULT_PROTOCOLS[df_type] - else: - protocol = protocol_prefix(python_val.uri) + protocol = self._protocol_from_type_or_prefix(ctx, df_type, python_val.uri) return self.encode( ctx, python_val, @@ -485,13 +482,31 @@ def to_literal( ) # Otherwise assume it's a dataframe instance. 
Wrap it with some defaults - fmt = self.DEFAULT_FORMATS[python_type] - protocol = self.DEFAULT_PROTOCOLS[python_type] + if python_type in self.DEFAULT_FORMATS: + fmt = self.DEFAULT_FORMATS[python_type] + else: + logger.debug(f"No default format for type {python_type}, using system default.") + fmt = StructuredDataset.DEFAULT_FILE_FORMAT + protocol = self._protocol_from_type_or_prefix(ctx, python_type) meta = StructuredDatasetMetadata(structured_dataset_type=expected.structured_dataset_type if expected else None) sd = StructuredDataset(dataframe=python_val, metadata=meta) return self.encode(ctx, sd, python_type, protocol, fmt, sdt) + def _protocol_from_type_or_prefix(self, ctx: FlyteContext, df_type: Type, uri: Optional[str] = None) -> str: + """ + Get the protocol from the default, if missing, then look it up from the uri if provided, if not then look + up from the provided context's file access. + """ + if df_type in self.DEFAULT_PROTOCOLS: + return self.DEFAULT_PROTOCOLS[df_type] + else: + protocol = protocol_prefix(uri or ctx.file_access.raw_output_prefix) + logger.debug( + f"No default protocol for type {df_type} found, using {protocol} from output prefix {ctx.file_access.raw_output_prefix}" + ) + return protocol + def encode( self, ctx: FlyteContext, diff --git a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py index 1a667fe699..f79e97efb5 100644 --- a/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py +++ b/plugins/flytekit-polars/flytekitplugins/polars/sd_transformers.py @@ -62,12 +62,10 @@ def decode( return pl.read_parquet(path) -for protocol in [LOCAL, S3]: +for protocol in [LOCAL, S3, GCS]: StructuredDatasetTransformerEngine.register( - PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=True + PolarsDataFrameToParquetEncodingHandler(protocol), default_for_type=False ) StructuredDatasetTransformerEngine.register( - 
ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=True + ParquetToPolarsDataFrameDecodingHandler(protocol), default_for_type=False ) -StructuredDatasetTransformerEngine.register(PolarsDataFrameToParquetEncodingHandler(GCS), default_for_type=False) -StructuredDatasetTransformerEngine.register(ParquetToPolarsDataFrameDecodingHandler(GCS), default_for_type=False) diff --git a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py index cd451fa080..651860d4b7 100644 --- a/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py +++ b/plugins/flytekit-spark/flytekitplugins/spark/sd_transformers.py @@ -49,5 +49,5 @@ def decode( for protocol in ["/", "s3"]: - StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=True) - StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=True) + StructuredDatasetTransformerEngine.register(SparkToParquetEncodingHandler(protocol), default_for_type=False) + StructuredDatasetTransformerEngine.register(ParquetToSparkDecodingHandler(protocol), default_for_type=False) diff --git a/tests/flytekit/unit/core/test_data_persistence.py b/tests/flytekit/unit/core/test_data_persistence.py index 9cf867ea88..e61350a7ed 100644 --- a/tests/flytekit/unit/core/test_data_persistence.py +++ b/tests/flytekit/unit/core/test_data_persistence.py @@ -5,6 +5,7 @@ def test_get_random_remote_path(): fp = FileAccessProvider("/tmp", "s3://my-bucket") path = fp.get_random_remote_path() assert path.startswith("s3://my-bucket") + assert fp.raw_output_prefix == "s3://my-bucket" def test_is_remote(): diff --git a/tests/flytekit/unit/core/test_dataclass.py b/tests/flytekit/unit/core/test_dataclass.py new file mode 100644 index 0000000000..db49d2312c --- /dev/null +++ b/tests/flytekit/unit/core/test_dataclass.py @@ -0,0 +1,29 @@ +from dataclasses import dataclass +from typing import 
List + +from dataclasses_json import dataclass_json + +from flytekit.core.task import task +from flytekit.core.workflow import workflow + + +def test_dataclass(): + @dataclass_json + @dataclass + class AppParams(object): + snapshotDate: str + region: str + preprocess: bool + listKeys: List[str] + + @task + def t1() -> AppParams: + ap = AppParams(snapshotDate="4/5/2063", region="us-west-3", preprocess=False, listKeys=["a", "b"]) + return ap + + @workflow + def wf() -> AppParams: + return t1() + + res = wf() + assert res.region == "us-west-3" diff --git a/tests/flytekit/unit/core/test_structured_dataset.py b/tests/flytekit/unit/core/test_structured_dataset.py index 9982918290..773d30b5ae 100644 --- a/tests/flytekit/unit/core/test_structured_dataset.py +++ b/tests/flytekit/unit/core/test_structured_dataset.py @@ -1,3 +1,4 @@ +import tempfile import typing import pytest @@ -5,6 +6,7 @@ import flytekit.configuration from flytekit.configuration import Image, ImageConfig from flytekit.core.context_manager import FlyteContext, FlyteContextManager +from flytekit.core.data_persistence import FileAccessProvider from flytekit.core.type_engine import TypeEngine from flytekit.models import literals from flytekit.models.literals import StructuredDatasetMetadata @@ -20,8 +22,8 @@ from flytekit import kwtypes, task from flytekit.types.structured.structured_dataset import ( + LOCAL, PARQUET, - S3, StructuredDataset, StructuredDatasetDecoder, StructuredDatasetEncoder, @@ -336,7 +338,7 @@ def test_format_correct(): class TempEncoder(StructuredDatasetEncoder): def __init__(self): - super().__init__(pd.DataFrame, S3, "avro") + super().__init__(pd.DataFrame, LOCAL, "avro") def encode( self, @@ -375,3 +377,21 @@ def t1() -> Annotated[StructuredDataset, "avro"]: return StructuredDataset(dataframe=df) assert t1().file_format == "avro" + + +def test_protocol_detection(): + # We don't register defaults to the transformer engine + assert
pd.DataFrame not in StructuredDatasetTransformerEngine.DEFAULT_PROTOCOLS + e = StructuredDatasetTransformerEngine() + ctx = FlyteContextManager.current_context() + protocol = e._protocol_from_type_or_prefix(ctx, pd.DataFrame) + assert protocol == "/" + + with tempfile.TemporaryDirectory() as tmp_dir: + fs = FileAccessProvider(local_sandbox_dir=tmp_dir, raw_output_prefix="s3://fdsa") + ctx2 = ctx.with_file_access(fs).build() + protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame) + assert protocol == "s3" + + protocol = e._protocol_from_type_or_prefix(ctx2, pd.DataFrame, "bq://foo") + assert protocol == "bq"