From b176c1c935ac7de49aa97ac1131d89168d09aef5 Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 18 Oct 2024 11:51:27 -0700 Subject: [PATCH] Catch mistake in structured dataset (#2834) Signed-off-by: Yee Hing Tong --- flytekit/types/structured/structured_dataset.py | 7 +++++++ tests/flytekit/unit/core/test_type_engine.py | 13 +++++++++++++ 2 files changed, 20 insertions(+) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 52361f00083..12a1b1ca28e 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -607,6 +607,13 @@ def to_literal( # In case it's a FlyteSchema sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) + if issubclass(python_type, StructuredDataset) and not isinstance(python_val, StructuredDataset): + # Catch a common mistake + raise TypeTransformerFailedError( + f"Expected a StructuredDataset instance, but got {type(python_val)} instead." + f" Did you forget to wrap your dataframe in a StructuredDataset instance?" + ) + if expected and expected.structured_dataset_type: sdt = StructuredDatasetType( columns=expected.structured_dataset_type.columns, diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index c63199663b6..2f8e5a8a3e9 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3692,3 +3692,16 @@ def test_structured_dataset_collection(): lv = TypeEngine.to_literal(FlyteContext.current_context(), [[StructuredDataset(df)]], WineTypeListList, lt) assert lv is not None + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_structured_dataset_mismatch(): + import pandas as pd + + df = pd.DataFrame({"alcohol": [1.0, 2.0], "malic_acid": [2.0, 3.0]}) + transformer = TypeEngine.get_transformer(StructuredDataset) + with pytest.raises(TypeTransformerFailedError): + transformer.to_literal(FlyteContext.current_context(), df, StructuredDataset, TypeEngine.to_literal_type(StructuredDataset)) + + with pytest.raises(TypeTransformerFailedError): + TypeEngine.to_literal(FlyteContext.current_context(), df, StructuredDataset, TypeEngine.to_literal_type(StructuredDataset))