From 59303b2c9241307ed91157e489e433d8b14ed1b2 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 9 Dec 2024 17:24:37 +0200 Subject: [PATCH 1/5] add struct metadata type parsing --- src/firebolt/common/_types.py | 83 ++++++++++++++++++++++++++++------- tests/unit/db_conftest.py | 34 ++++++++++++-- 2 files changed, 98 insertions(+), 19 deletions(-) diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 290c93c47e..79be19fb7b 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -5,7 +5,8 @@ from datetime import date, datetime, timezone from decimal import Decimal from enum import Enum -from typing import List, Optional, Sequence, Union +from io import StringIO +from typing import Any, Dict, List, Optional, Sequence, Union from sqlparse import parse as parse_sql # type: ignore from sqlparse.sql import ( # type: ignore @@ -62,8 +63,6 @@ def parse_datetime(datetime_string: str) -> datetime: # These definitions are required by PEP-249 Date = date -_AccountInfo = namedtuple("_AccountInfo", ["id", "version"]) - def DateFromTicks(t: int) -> date: # NOSONAR """Convert `ticks` to `date` for Firebolt DB.""" @@ -109,16 +108,25 @@ def Binary(value: str) -> bytes: # NOSONAR ) -class ARRAY: +class ExtendedType: + """Base type for all extended types in Firebolt (array, decimal, struct, etc.).""" + + @staticmethod + def is_valid_type(type_: Any) -> bool: + return type_ in _col_types or isinstance(type_, ExtendedType) + + def __hash__(self) -> int: + return hash(str(self)) + + +class ARRAY(ExtendedType): """Class for holding `array` column type information in Firebolt DB.""" __name__ = "Array" _prefix = "array(" - def __init__(self, subtype: Union[type, ARRAY, DECIMAL]): - assert (subtype in _col_types and subtype is not list) or isinstance( - subtype, (ARRAY, DECIMAL) - ), f"Invalid array subtype: {str(subtype)}" + def __init__(self, subtype: Union[type, ARRAY, DECIMAL, STRUCT]): + assert self.is_valid_type(subtype), f"Invalid array subtype: {str(subtype)}" self.subtype = subtype def __str__(self) -> str: @@ -130,7 +138,7 @@ def __eq__(self, other: object) -> bool: return other.subtype == self.subtype -class DECIMAL: +class DECIMAL(ExtendedType): """Class for holding `decimal` value information in Firebolt DB.""" __name__ = "Decimal" @@ -143,15 +151,28 @@ def __init__(self, precision: int, scale: int): def __str__(self) -> str: return f"Decimal({self.precision}, {self.scale})" - def __hash__(self) -> int: - return hash(str(self)) - def __eq__(self, other: object) -> bool: if not isinstance(other, DECIMAL): return NotImplemented return other.precision == self.precision and other.scale == self.scale +class STRUCT(ExtendedType): + __name__ = "Struct" + _prefix = "Struct(" + + def __init__(self, fields: Dict[str, Union[type, ARRAY, DECIMAL, STRUCT]]): + for name, type_ in fields.items(): + assert self.is_valid_type(type_), f"Invalid struct field type: {str(type_)}" + self.fields = fields + + def __str__(self) -> str: + return f"Struct({', '.join(f'{k}: {v}' for k, v in self.fields.items())})" + + def __eq__(self, other: Any) -> bool: + return isinstance(other, STRUCT) and other.fields == self.fields + + NULLABLE_SUFFIX = "null" @@ -206,7 +227,27 @@ def python_type(self) -> type: return types[self] -def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901 +def split_struct_fields(raw_struct: str) -> List[str]: + balance = 0 + separator = "," + res = [] + current = StringIO() + for i, ch in enumerate(raw_struct): + if ch == "(": + balance += 1 + elif ch == ")": + balance -= 1 + elif ch == separator and balance == 0: + res.append(current.getvalue()) + current = StringIO() + continue + current.write(ch) + + res.append(current.getvalue()) + return res + + +def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL, STRUCT]: # noqa: C901 """Parse typename provided by query metadata into Python type.""" if not isinstance(raw_type, str): raise DataError(f"Invalid typename {str(raw_type)}: str expected") @@ -218,10 +259,20 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901 try: prec_scale = raw_type[len(DECIMAL._prefix) : -1].split(",") precision, scale = int(prec_scale[0]), int(prec_scale[1]) + return DECIMAL(precision, scale) except (ValueError, IndexError): pass - else: - return DECIMAL(precision, scale) + # Handle structs + if raw_type.startswith(STRUCT._prefix) and raw_type.endswith(")"): + try: + fields_raw = split_struct_fields(raw_type[len(STRUCT._prefix) : -1]) + fields = {} + for f in fields_raw: + name, type_ = f.strip().split(" ", 1) + fields[name.strip()] = parse_type(type_.strip()) + return STRUCT(fields) + except ValueError: + pass # Handle nullable if raw_type.endswith(NULLABLE_SUFFIX): return parse_type(raw_type[: -len(NULLABLE_SUFFIX)].strip(" ")) @@ -247,7 +298,7 @@ def _parse_bytea(str_value: str) -> bytes: def parse_value( value: RawColType, - ctype: Union[type, ARRAY, DECIMAL], + ctype: Union[type, ARRAY, DECIMAL, STRUCT], ) -> ColType: """Provided raw value, and Python type; parses first into Python value.""" if value is None: diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index a5c76bad81..58c799228a 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -8,6 +8,7 @@ from pytest_httpx import HTTPXMock from firebolt.async_db.cursor import JSON_OUTPUT_FORMAT, ColType, Column +from firebolt.common._types import STRUCT from firebolt.db import ARRAY, DECIMAL from firebolt.utils.urls import GATEWAY_HOST_BY_ACCOUNT_NAME from tests.unit.response import Response @@ -482,7 +483,7 @@ def types_map() -> Dict[str, type]: "timestamp": datetime, "timestampntz": datetime, "timestamptz": datetime, - "Nothing null": str, + "Nothing": str, "Decimal(123, 4)": DECIMAL(123, 4), "Decimal(38,0)": DECIMAL(38, 0), # Invalid decimal format @@ -491,7 +492,34 @@ def types_map() -> Dict[str, type]: "SomeRandomNotExistingType": str, "bytea": bytes, } - array_types = {f"array({k})": ARRAY(v) for k, v in base_types.items()} + nullable_types = {f"{k} null": v for k, v in base_types.items()} + array_types = { + f"array({k})": ARRAY(v) + for k, v in (*base_types.items(), *nullable_types.items()) + } nullable_arrays = {f"{k} null": v for k, v in array_types.items()} nested_arrays = {f"array({k})": ARRAY(v) for k, v in array_types.items()} - return {**base_types, **array_types, **nullable_arrays, **nested_arrays} + + struct_keys, struct_fields = list( + zip(*base_types.items(), *nullable_types.items(), *array_types.items()) + ) + # Create column names by replacing invalid characters with underscores + trans = str.maketrans({ch: "_" for ch in " (),"}) + + struct_items = [f"{key.translate(trans)}_col {key}" for key in struct_keys] + struct_type = f"Struct({', '.join(struct_items)})" + struct_field_names = [f"{key.translate(trans)}_col" for key in struct_keys] + struct = {struct_type: STRUCT(dict(zip(struct_field_names, struct_fields)))} + nested_struct = { + f"Struct(s {struct_type} null)": STRUCT({"s": list(struct.values())[0]}) + } + + return { + **base_types, + **nullable_types, + **array_types, + **nullable_arrays, + **nested_arrays, + **struct, + **nested_struct, + } From b9296c6dadc11219534608d844cc94e203095dd3 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Mon, 9 Dec 2024 17:41:41 +0200 Subject: [PATCH 2/5] add struct value parsing --- src/firebolt/async_db/__init__.py | 1 + src/firebolt/common/_types.py | 6 ++++ src/firebolt/db/__init__.py | 1 + tests/unit/common/test_typing_parse.py | 49 ++++++++++++++++++++++++-- 4 files changed, 55 insertions(+), 2 deletions(-) diff --git a/src/firebolt/async_db/__init__.py b/src/firebolt/async_db/__init__.py index 35ff4a15a1..e492ddd816 100644 --- a/src/firebolt/async_db/__init__.py +++ b/src/firebolt/async_db/__init__.py @@ -8,6 +8,7 @@ NUMBER, ROWID, STRING, + STRUCT, Binary, Date, DateFromTicks, diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 79be19fb7b..78b5f165ba 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -329,6 +329,12 @@ def parse_value( if isinstance(ctype, ARRAY): assert isinstance(value, list) return [parse_value(it, ctype.subtype) for it in value] + if isinstance(ctype, STRUCT): + assert isinstance(value, dict) + return { + name: parse_value(value.get(name), type_) + for name, type_ in ctype.fields.items() + } raise DataError(f"Unsupported data type returned: {ctype.__name__}") diff --git a/src/firebolt/db/__init__.py b/src/firebolt/db/__init__.py index fcf6ae53c8..b101a2a746 100644 --- a/src/firebolt/db/__init__.py +++ b/src/firebolt/db/__init__.py @@ -6,6 +6,7 @@ NUMBER, ROWID, STRING, + STRUCT, Binary, Date, DateFromTicks, diff --git a/tests/unit/common/test_typing_parse.py b/tests/unit/common/test_typing_parse.py index 789fab571a..971fccb216 100644 --- a/tests/unit/common/test_typing_parse.py +++ b/tests/unit/common/test_typing_parse.py @@ -5,14 +5,16 @@ from pytest import mark, raises -from firebolt.async_db import ( +from firebolt.common._types import ( ARRAY, DECIMAL, + STRUCT, DateFromTicks, TimeFromTicks, TimestampFromTicks, + parse_type, + parse_value, ) -from firebolt.common._types import parse_type, parse_value from firebolt.utils.exception import DataError, NotSupportedError @@ -242,6 +244,7 @@ def test_parse_decimal(value, expected) -> None: (None, None, datetime), (None, None, date), (None, None, ARRAY(int)), + ([{"a": 1}, {"a": 2}], [{"a": 1}, {"a": 2}], STRUCT({"a": int})), ], ) def test_parse_arrays(value, expected, type) -> None: @@ -306,3 +309,45 @@ def test_parse_value_bytes(value, expected, error) -> None: assert ( parse_value(value, bytes) == expected ), f"Error parsing bytes: provided {value}" + + +@mark.parametrize( + "value,expected,type_,error", + [ + ( + {"a": 1, "b": False}, + {"a": 1, "b": False}, + STRUCT({"a": int, "b": bool}), + None, + ), + ( + {"a": 1, "b": "a"}, + {"a": 1, "b": "1"}, + STRUCT({"a": int, "b": bool}), + DataError, + ), + ( + {"dt": "2021-12-31 23:59:59", "d": "2021-12-31"}, + {"dt": datetime(2021, 12, 31, 23, 59, 59), "d": date(2021, 12, 31)}, + STRUCT({"dt": datetime, "d": date}), + None, + ), + ( + {"a": 1, "s": {"b": "2021-12-31"}}, + {"a": 1, "s": {"b": date(2021, 12, 31)}}, + STRUCT({"a": int, "s": STRUCT({"b": date})}), + None, + ), + (None, None, STRUCT({"a": int, "b": bool}), None), + ({"a": [1, 2, 3]}, {"a": [1, 2, 3]}, STRUCT({"a": ARRAY(int)}), None), + ], +) +def test_parse_value_struct(value, expected, type_, error) -> None: + """parse_value parses all int values correctly.""" + if error: + with raises(error): + parse_value(value, type_) + else: + assert ( + parse_value(value, type_) == expected + ), f"Error parsing struct: provided {value}" From 8746b3235992e77e99e885b928a151936b746575 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 11 Dec 2024 15:48:29 +0200 Subject: [PATCH 3/5] add integration test --- src/firebolt/common/_types.py | 2 +- .../dbapi/async/V2/test_queries_async.py | 26 +++++++++ tests/integration/dbapi/conftest.py | 53 ++++++++++++++++++- .../integration/dbapi/sync/V2/test_queries.py | 26 +++++++++ tests/unit/db_conftest.py | 4 +- 5 files changed, 107 insertions(+), 4 deletions(-) diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 78b5f165ba..943adc3ab9 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -159,7 +159,7 @@ def __eq__(self, other: object) -> bool: class STRUCT(ExtendedType): __name__ = "Struct" - _prefix = "Struct(" + _prefix = "struct(" def __init__(self, fields: Dict[str, Union[type, ARRAY, DECIMAL, STRUCT]]): for name, type_ in fields.items(): diff --git a/tests/integration/dbapi/async/V2/test_queries_async.py b/tests/integration/dbapi/async/V2/test_queries_async.py index a9a13dab4d..30c232bbfe 100644 --- a/tests/integration/dbapi/async/V2/test_queries_async.py +++ b/tests/integration/dbapi/async/V2/test_queries_async.py @@ -429,3 +429,29 @@ async def test_select_geography( select_geography_response, "Invalid data returned by fetchall", ) + + +async def test_select_struct( + connection: Connection, + setup_struct_query: str, + cleanup_struct_query: str, + select_struct_query: str, + select_struct_description: List[Column], + select_struct_response: List[ColType], +): + with connection.cursor() as c: + try: + await c.execute(setup_struct_query) + await c.execute(select_struct_query) + assert ( + c.description == select_struct_description + ), "Invalid description value" + res = await c.fetchall() + assert len(res) == 1, "Invalid data length" + assert_deep_eq( + res, + select_struct_response, + "Invalid data returned by fetchall", + ) + finally: + await c.execute(cleanup_struct_query) diff --git a/tests/integration/dbapi/conftest.py b/tests/integration/dbapi/conftest.py index eb0bb1b72b..9f485c9238 100644 --- a/tests/integration/dbapi/conftest.py +++ b/tests/integration/dbapi/conftest.py @@ -6,7 +6,7 @@ from pytest import fixture from firebolt.async_db.cursor import Column -from firebolt.common._types import ColType +from firebolt.common._types import STRUCT, ColType from firebolt.db import ARRAY, DECIMAL, Connection LOGGER = getLogger(__name__) @@ -209,3 +209,54 @@ def select_geography_description() -> List[Column]: @fixture def select_geography_response() -> List[ColType]: return [["0101000020E6100000FEFFFFFFFFFFEF3F000000000000F03F"]] + + +@fixture +def setup_struct_query() -> str: + return """ + SET advanced_mode=1; + SET enable_struct=1; + SET enable_create_table_v2=true; + SET enable_row_selection=true; + SET prevent_create_on_information_schema=true; + SET enable_create_table_with_struct_type=true; + DROP TABLE IF EXISTS test_struct; + DROP TABLE IF EXISTS test_struct_helper; + CREATE TABLE IF NOT EXISTS test_struct(id int not null, s struct(a array(int) not null, b datetime null) not null); + CREATE TABLE IF NOT EXISTS test_struct_helper(a array(int) not null, b datetime null); + INSERT INTO test_struct_helper(a, b) VALUES ([1, 2], '2019-07-31 01:01:01'); + INSERT INTO test_struct(id, s) SELECT 1, test_struct_helper FROM test_struct_helper; + """ + + +@fixture +def cleanup_struct_query() -> str: + return """ + DROP TABLE IF EXISTS test_struct; + DROP TABLE IF EXISTS test_struct_helper; + """ + + +@fixture +def select_struct_query() -> str: + return "SELECT test_struct FROM test_struct" + + +@fixture +def select_struct_description() -> List[Column]: + return [ + Column( + "test_struct", + STRUCT({"id": int, "s": STRUCT({"a": ARRAY(int), "b": datetime})}), + None, + None, + None, + None, + None, + ) + ] + + +@fixture +def select_struct_response() -> List[ColType]: + return [[{"id": 1, "s": {"a": [1, 2], "b": datetime(2019, 7, 31, 1, 1, 1)}}]] diff --git a/tests/integration/dbapi/sync/V2/test_queries.py b/tests/integration/dbapi/sync/V2/test_queries.py index cea1e84504..4324407ec6 100644 --- a/tests/integration/dbapi/sync/V2/test_queries.py +++ b/tests/integration/dbapi/sync/V2/test_queries.py @@ -512,3 +512,29 @@ def test_select_geography( select_geography_response, "Invalid data returned by fetchall", ) + + +def test_select_struct( + connection: Connection, + setup_struct_query: str, + cleanup_struct_query: str, + select_struct_query: str, + select_struct_description: List[Column], + select_struct_response: List[ColType], +): + with connection.cursor() as c: + try: + c.execute(setup_struct_query) + c.execute(select_struct_query) + assert ( + c.description == select_struct_description + ), "Invalid description value" + res = c.fetchall() + assert len(res) == 1, "Invalid data length" + assert_deep_eq( + res, + select_struct_response, + "Invalid data returned by fetchall", + ) + finally: + c.execute(cleanup_struct_query) diff --git a/tests/unit/db_conftest.py b/tests/unit/db_conftest.py index 58c799228a..cf1630a883 100644 --- a/tests/unit/db_conftest.py +++ b/tests/unit/db_conftest.py @@ -507,11 +507,11 @@ def types_map() -> Dict[str, type]: trans = str.maketrans({ch: "_" for ch in " (),"}) struct_items = [f"{key.translate(trans)}_col {key}" for key in struct_keys] - struct_type = f"Struct({', '.join(struct_items)})" + struct_type = f"struct({', '.join(struct_items)})" struct_field_names = [f"{key.translate(trans)}_col" for key in struct_keys] struct = {struct_type: STRUCT(dict(zip(struct_field_names, struct_fields)))} nested_struct = { - f"Struct(s {struct_type} null)": STRUCT({"s": list(struct.values())[0]}) + f"struct(s {struct_type} null)": STRUCT({"s": list(struct.values())[0]}) } return { From 6ee95eb6b2a47f73468ba6a366a06d3f28e5ae1d Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Wed, 11 Dec 2024 15:54:30 +0200 Subject: [PATCH 4/5] improve typing --- src/firebolt/common/_types.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 943adc3ab9..513c30816e 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -111,6 +111,8 @@ def Binary(value: str) -> bytes: # NOSONAR class ExtendedType: """Base type for all extended types in Firebolt (array, decimal, struct, etc.).""" + __name__ = "ExtendedType" + @staticmethod def is_valid_type(type_: Any) -> bool: return type_ in _col_types or isinstance(type_, ExtendedType) @@ -125,7 +127,7 @@ class ARRAY(ExtendedType): __name__ = "Array" _prefix = "array(" - def __init__(self, subtype: Union[type, ARRAY, DECIMAL, STRUCT]): + def __init__(self, subtype: Union[type, ExtendedType]): assert self.is_valid_type(subtype), f"Invalid array subtype: {str(subtype)}" self.subtype = subtype @@ -161,7 +163,7 @@ class STRUCT(ExtendedType): __name__ = "Struct" _prefix = "struct(" - def __init__(self, fields: Dict[str, Union[type, ARRAY, DECIMAL, STRUCT]]): + def __init__(self, fields: Dict[str, Union[type, ExtendedType]]): for name, type_ in fields.items(): assert self.is_valid_type(type_), f"Invalid struct field type: {str(type_)}" self.fields = fields @@ -247,7 +249,7 @@ def split_struct_fields(raw_struct: str) -> List[str]: return res -def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL, STRUCT]: # noqa: C901 +def parse_type(raw_type: str) -> Union[type, ExtendedType]: # noqa: C901 """Parse typename provided by query metadata into Python type.""" if not isinstance(raw_type, str): raise DataError(f"Invalid typename {str(raw_type)}: str expected") @@ -298,7 +300,7 @@ def _parse_bytea(str_value: str) -> bytes: def parse_value( value: RawColType, - ctype: Union[type, ARRAY, DECIMAL, STRUCT], + ctype: Union[type, ExtendedType], ) -> ColType: """Provided raw value, and Python type; parses first into Python value.""" if value is None: From 1312534f310b066c9eeb9bedd674b40c2130b1d1 Mon Sep 17 00:00:00 2001 From: Stepan Burlakov Date: Thu, 12 Dec 2024 15:51:31 +0200 Subject: [PATCH 5/5] resolve comments --- src/firebolt/common/_types.py | 23 ++++++++++++++++------- tests/unit/common/test_typing_parse.py | 21 +++++++++++++++++++++ 2 files changed, 37 insertions(+), 7 deletions(-) diff --git a/src/firebolt/common/_types.py b/src/firebolt/common/_types.py index 513c30816e..1a4d27ab5e 100644 --- a/src/firebolt/common/_types.py +++ b/src/firebolt/common/_types.py @@ -128,7 +128,8 @@ class ARRAY(ExtendedType): _prefix = "array(" def __init__(self, subtype: Union[type, ExtendedType]): - assert self.is_valid_type(subtype), f"Invalid array subtype: {str(subtype)}" + if not self.is_valid_type(subtype): + raise ValueError(f"Invalid array subtype: {str(subtype)}") self.subtype = subtype def __str__(self) -> str: @@ -165,7 +166,8 @@ class STRUCT(ExtendedType): def __init__(self, fields: Dict[str, Union[type, ExtendedType]]): for name, type_ in fields.items(): - assert self.is_valid_type(type_), f"Invalid struct field type: {str(type_)}" + if not self.is_valid_type(type_): + raise ValueError(f"Invalid struct field type: {str(type_)}") self.fields = fields def __str__(self) -> str: @@ -230,7 +232,11 @@ def python_type(self) -> type: def split_struct_fields(raw_struct: str) -> List[str]: - balance = 0 + """Split raw struct inner fields string into a list of field definitions. + >>> split_struct_fields("field1 int, field2 struct(field1 int, field2 text)") + ["field1 int", "field2 struct(field1 int, field2 text)"] + """ + balance = 0 # keep track of the level of nesting, and only split on level 0 separator = "," res = [] current = StringIO() @@ -306,7 +312,7 @@ def parse_value( if value is None: return None if ctype in (int, str, float): - assert isinstance(ctype, type) + assert isinstance(ctype, type) # assertion for mypy return ctype(value) if ctype is date: if not isinstance(value, str): @@ -326,13 +332,16 @@ def parse_value( raise DataError(f"Invalid bytea value {value}: str expected") return _parse_bytea(value) if isinstance(ctype, DECIMAL): - assert isinstance(value, (str, int)) + if not isinstance(value, (str, int)): + raise DataError(f"Invalid decimal value {value}: str or int expected") return Decimal(value) if isinstance(ctype, ARRAY): - assert isinstance(value, list) + if not isinstance(value, list): + raise DataError(f"Invalid array value {value}: list expected") return [parse_value(it, ctype.subtype) for it in value] if isinstance(ctype, STRUCT): - assert isinstance(value, dict) + if not isinstance(value, dict): + raise DataError(f"Invalid struct value {value}: dict expected") return { name: parse_value(value.get(name), type_) for name, type_ in ctype.fields.items() diff --git a/tests/unit/common/test_typing_parse.py b/tests/unit/common/test_typing_parse.py index 971fccb216..c6ba2c501c 100644 --- a/tests/unit/common/test_typing_parse.py +++ b/tests/unit/common/test_typing_parse.py @@ -14,6 +14,7 @@ TimestampFromTicks, parse_type, parse_value, + split_struct_fields, ) from firebolt.utils.exception import DataError, NotSupportedError @@ -338,6 +339,12 @@ def test_parse_value_bytes(value, expected, error) -> None: STRUCT({"a": int, "s": STRUCT({"b": date})}), None, ), + ( + {"a": None, "b": None}, + {"a": None, "b": None}, + STRUCT({"a": int, "b": bool}), + None, + ), (None, None, STRUCT({"a": int, "b": bool}), None), ({"a": [1, 2, 3]}, {"a": [1, 2, 3]}, STRUCT({"a": ARRAY(int)}), None), ], @@ -351,3 +358,17 @@ def test_parse_value_struct(value, expected, type_, error) -> None: assert ( parse_value(value, type_) == expected ), f"Error parsing struct: provided {value}" + + +@mark.parametrize( + "value,expected", + [ + ("a int, b text", ["a int", " b text"]), + ("a int, s struct(a int, b text)", ["a int", " s struct(a int, b text)"]), + ("a int, b array(struct(a int))", ["a int", " b array(struct(a int))"]), + ], +) +def test_split_struct_fields(value, expected) -> None: + assert ( + split_struct_fields(value) == expected + ), f"Error splitting struct fields: provided {value}, expected {expected}"