From a98c4c0d24570a4cbd9fefa5e99d6d3c40a38c7a Mon Sep 17 00:00:00 2001 From: Yee Hing Tong Date: Fri, 18 Oct 2024 11:08:40 -0700 Subject: [PATCH] add an error to catch a common mistake Signed-off-by: Yee Hing Tong --- flytekit/types/structured/structured_dataset.py | 7 +++++++ tests/flytekit/unit/core/test_type_engine.py | 10 ++++++++++ 2 files changed, 17 insertions(+) diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py index 52361f0008..12a1b1ca28 100644 --- a/flytekit/types/structured/structured_dataset.py +++ b/flytekit/types/structured/structured_dataset.py @@ -607,6 +607,13 @@ def to_literal( # In case it's a FlyteSchema sdt = StructuredDatasetType(format=self.DEFAULT_FORMATS.get(python_type, GENERIC_FORMAT)) + if issubclass(python_type, StructuredDataset) and not isinstance(python_val, StructuredDataset): + # Catch a common mistake + raise TypeTransformerFailedError( + f"Expected a StructuredDataset instance, but got {type(python_val)} instead." + f" Did you forget to wrap your dataframe in a StructuredDataset instance?" + ) + if expected and expected.structured_dataset_type: sdt = StructuredDatasetType( columns=expected.structured_dataset_type.columns, diff --git a/tests/flytekit/unit/core/test_type_engine.py b/tests/flytekit/unit/core/test_type_engine.py index c63199663b..eaf0d699fe 100644 --- a/tests/flytekit/unit/core/test_type_engine.py +++ b/tests/flytekit/unit/core/test_type_engine.py @@ -3692,3 +3692,13 @@ def test_structured_dataset_collection(): lv = TypeEngine.to_literal(FlyteContext.current_context(), [[StructuredDataset(df)]], WineTypeListList, lt) assert lv is not None + + +@pytest.mark.skipif("pandas" not in sys.modules, reason="Pandas is not installed.") +def test_structured_dataset_mismatch(): + import pandas as pd + + df = pd.DataFrame({"alcohol": [1.0, 2.0], "malic_acid": [2.0, 3.0]}) + transformer = TypeEngine.get_transformer(StructuredDataset) + with pytest.raises(TypeTransformerFailedError): + transformer.to_literal(FlyteContext.current_context(), df, StructuredDataset, TypeEngine.to_literal_type(StructuredDataset))