diff --git a/src/snowflake/snowpark/types.py b/src/snowflake/snowpark/types.py index 59a1a6fdae9..f78b35e1f97 100644 --- a/src/snowflake/snowpark/types.py +++ b/src/snowflake/snowpark/types.py @@ -683,6 +683,8 @@ def __init__( ) else: self.structured = structured or False + + self.fields = [] for field in fields or []: self.add(field) diff --git a/tests/integ/scala/test_datatype_suite.py b/tests/integ/scala/test_datatype_suite.py index a1bd1d48acd..1004190b00e 100644 --- a/tests/integ/scala/test_datatype_suite.py +++ b/tests/integ/scala/test_datatype_suite.py @@ -856,9 +856,9 @@ def test_structured_dtypes_cast(structured_type_session, structured_type_support pytest.skip("Test requires structured type support.") expected_semi_schema = StructType( [ - StructField("ARR", ArrayType(StringType()), nullable=True), - StructField("MAP", MapType(StringType(), StringType()), nullable=True), - StructField("OBJ", MapType(StringType(), StringType()), nullable=True), + StructField("ARR", ArrayType(), nullable=True), + StructField("MAP", MapType(), nullable=True), + StructField("OBJ", MapType(), nullable=True), ] ) expected_structured_schema = StructType( diff --git a/tests/unit/test_datatype_mapper.py b/tests/unit/test_datatype_mapper.py index af8b9cd3c1c..44e84df0a43 100644 --- a/tests/unit/test_datatype_mapper.py +++ b/tests/unit/test_datatype_mapper.py @@ -123,9 +123,17 @@ def test_to_sql(): assert ( to_sql([1, "2", 3.5], ArrayType()) == "PARSE_JSON('[1, \"2\", 3.5]') :: ARRAY" ) + assert ( + to_sql([1, 2, 3], ArrayType(IntegerType(), structured=True)) + == "PARSE_JSON('[1, 2, 3]') :: ARRAY(INT)" + ) assert ( to_sql({"'": '"'}, MapType()) == 'PARSE_JSON(\'{"\'\'": "\\\\""}\') :: OBJECT' ) + assert ( + to_sql({"'": '"'}, MapType(StringType(), structured=True)) + == 'PARSE_JSON(\'{"\'\'": "\\\\""}\') :: MAP(STRING, STRING)' + ) assert to_sql([{1: 2}], ArrayType()) == "PARSE_JSON('[{\"1\": 2}]') :: ARRAY" assert to_sql({1: [2]}, MapType()) == "PARSE_JSON('{\"1\": [2]}') :: OBJECT" diff --git a/tests/unit/test_types.py b/tests/unit/test_types.py index d5ffc9757f6..db5355d1cec 100644 --- a/tests/unit/test_types.py +++ b/tests/unit/test_types.py @@ -683,6 +683,7 @@ def {func_name}(x, y {datatype_str} = {annotated_value}) -> None: @pytest.mark.parametrize( "value_str,datatype,expected_value", [ + (None, None, None), ("1", IntegerType(), 1), ("True", BooleanType(), True), ("1.0", FloatType(), 1.0),