Skip to content

Commit

Permalink
Merge 1312534 into 28435b9
Browse files Browse the repository at this point in the history
  • Loading branch information
stepansergeevitch authored Dec 12, 2024
2 parents 28435b9 + 1312534 commit c7526ea
Show file tree
Hide file tree
Showing 8 changed files with 292 additions and 25 deletions.
1 change: 1 addition & 0 deletions src/firebolt/async_db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
NUMBER,
ROWID,
STRING,
STRUCT,
Binary,
Date,
DateFromTicks,
Expand Down
106 changes: 87 additions & 19 deletions src/firebolt/common/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from datetime import date, datetime, timezone
from decimal import Decimal
from enum import Enum
from typing import List, Optional, Sequence, Union
from io import StringIO
from typing import Any, Dict, List, Optional, Sequence, Union

from sqlparse import parse as parse_sql # type: ignore
from sqlparse.sql import ( # type: ignore
Expand Down Expand Up @@ -62,8 +63,6 @@ def parse_datetime(datetime_string: str) -> datetime:
# These definitions are required by PEP-249
Date = date

_AccountInfo = namedtuple("_AccountInfo", ["id", "version"])


def DateFromTicks(t: int) -> date: # NOSONAR
"""Convert `ticks` to `date` for Firebolt DB."""
Expand Down Expand Up @@ -109,16 +108,28 @@ def Binary(value: str) -> bytes: # NOSONAR
)


class ARRAY:
class ExtendedType:
    """Common parent of the parametrised Firebolt column types.

    Array, decimal and struct type descriptors derive from this class so they
    can be recognised with a single ``isinstance`` check.
    """

    __name__ = "ExtendedType"

    @staticmethod
    def is_valid_type(type_: Any) -> bool:
        """Tell whether ``type_`` can be used as a column or nested type.

        A type is accepted when it is listed in ``_col_types`` or when it is
        an instance of an extended type.
        """
        if type_ in _col_types:
            return True
        return isinstance(type_, ExtendedType)

    def __hash__(self) -> int:
        """Hash the descriptor through its string representation."""
        text = str(self)
        return hash(text)


class ARRAY(ExtendedType):
"""Class for holding `array` column type information in Firebolt DB."""

__name__ = "Array"
_prefix = "array("

def __init__(self, subtype: Union[type, ARRAY, DECIMAL]):
assert (subtype in _col_types and subtype is not list) or isinstance(
subtype, (ARRAY, DECIMAL)
), f"Invalid array subtype: {str(subtype)}"
def __init__(self, subtype: Union[type, ExtendedType]):
if not self.is_valid_type(subtype):
raise ValueError(f"Invalid array subtype: {str(subtype)}")
self.subtype = subtype

def __str__(self) -> str:
Expand All @@ -130,7 +141,7 @@ def __eq__(self, other: object) -> bool:
return other.subtype == self.subtype


class DECIMAL:
class DECIMAL(ExtendedType):
"""Class for holding `decimal` value information in Firebolt DB."""

__name__ = "Decimal"
Expand All @@ -143,15 +154,29 @@ def __init__(self, precision: int, scale: int):
def __str__(self) -> str:
return f"Decimal({self.precision}, {self.scale})"

def __hash__(self) -> int:
return hash(str(self))

def __eq__(self, other: object) -> bool:
if not isinstance(other, DECIMAL):
return NotImplemented
return other.precision == self.precision and other.scale == self.scale


class STRUCT(ExtendedType):
    """Class for holding `struct` column type information in Firebolt DB.

    ``fields`` maps each field name to its type, which is either a plain
    column type or another extended type, so structs may be nested.

    Raises:
        ValueError: if any field type is not accepted by ``is_valid_type``.
    """

    __name__ = "Struct"
    _prefix = "struct("

    def __init__(self, fields: Dict[str, Union[type, ExtendedType]]):
        # Only the types need validating; field names are not inspected.
        for type_ in fields.values():
            if not self.is_valid_type(type_):
                raise ValueError(f"Invalid struct field type: {str(type_)}")
        self.fields = fields

    def __str__(self) -> str:
        return f"Struct({', '.join(f'{k}: {v}' for k, v in self.fields.items())})"

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, STRUCT) and other.fields == self.fields

    def __hash__(self) -> int:
        # A class that defines __eq__ without __hash__ gets __hash__ = None,
        # which would discard the hash inherited from the base class and make
        # instances unhashable. Dict equality ignores key order, so the hash
        # is built from the unordered set of field names: structs that
        # compare equal are guaranteed to hash equal.
        return hash((STRUCT, frozenset(self.fields)))


NULLABLE_SUFFIX = "null"


Expand Down Expand Up @@ -206,7 +231,31 @@ def python_type(self) -> type:
return types[self]


def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
def split_struct_fields(raw_struct: str) -> List[str]:
    """Split the inner part of a struct definition into field definitions.

    The string is split on commas that are not enclosed in parentheses, so
    nested types such as ``decimal(10, 2)`` or ``struct(...)`` stay intact.
    Pieces are returned verbatim: surrounding whitespace is NOT stripped, and
    an empty input yields a single empty string.

    >>> split_struct_fields("field1 int, field2 struct(field1 int, field2 text)")
    ['field1 int', ' field2 struct(field1 int, field2 text)']

    Args:
        raw_struct: text between ``struct(`` and the closing ``)``.

    Returns:
        The top-level field definitions, in their original order.
    """
    separator = ","
    fields: List[str] = []
    depth = 0  # current parenthesis nesting; split only at depth 0
    start = 0  # index where the field currently being scanned begins
    for pos, ch in enumerate(raw_struct):
        if ch == "(":
            depth += 1
        elif ch == ")":
            depth -= 1
        elif ch == separator and depth == 0:
            fields.append(raw_struct[start:pos])
            start = pos + 1

    # Whatever follows the last top-level separator is the final field.
    fields.append(raw_struct[start:])
    return fields


def parse_type(raw_type: str) -> Union[type, ExtendedType]: # noqa: C901
"""Parse typename provided by query metadata into Python type."""
if not isinstance(raw_type, str):
raise DataError(f"Invalid typename {str(raw_type)}: str expected")
Expand All @@ -218,10 +267,20 @@ def parse_type(raw_type: str) -> Union[type, ARRAY, DECIMAL]: # noqa: C901
try:
prec_scale = raw_type[len(DECIMAL._prefix) : -1].split(",")
precision, scale = int(prec_scale[0]), int(prec_scale[1])
return DECIMAL(precision, scale)
except (ValueError, IndexError):
pass
else:
return DECIMAL(precision, scale)
# Handle structs
if raw_type.startswith(STRUCT._prefix) and raw_type.endswith(")"):
try:
fields_raw = split_struct_fields(raw_type[len(STRUCT._prefix) : -1])
fields = {}
for f in fields_raw:
name, type_ = f.strip().split(" ", 1)
fields[name.strip()] = parse_type(type_.strip())
return STRUCT(fields)
except ValueError:
pass
# Handle nullable
if raw_type.endswith(NULLABLE_SUFFIX):
return parse_type(raw_type[: -len(NULLABLE_SUFFIX)].strip(" "))
Expand All @@ -247,13 +306,13 @@ def _parse_bytea(str_value: str) -> bytes:

def parse_value(
value: RawColType,
ctype: Union[type, ARRAY, DECIMAL],
ctype: Union[type, ExtendedType],
) -> ColType:
"""Provided raw value, and Python type; parses first into Python value."""
if value is None:
return None
if ctype in (int, str, float):
assert isinstance(ctype, type)
assert isinstance(ctype, type) # assertion for mypy
return ctype(value)
if ctype is date:
if not isinstance(value, str):
Expand All @@ -273,11 +332,20 @@ def parse_value(
raise DataError(f"Invalid bytea value {value}: str expected")
return _parse_bytea(value)
if isinstance(ctype, DECIMAL):
assert isinstance(value, (str, int))
if not isinstance(value, (str, int)):
raise DataError(f"Invalid decimal value {value}: str or int expected")
return Decimal(value)
if isinstance(ctype, ARRAY):
assert isinstance(value, list)
if not isinstance(value, list):
raise DataError(f"Invalid array value {value}: list expected")
return [parse_value(it, ctype.subtype) for it in value]
if isinstance(ctype, STRUCT):
if not isinstance(value, dict):
raise DataError(f"Invalid struct value {value}: dict expected")
return {
name: parse_value(value.get(name), type_)
for name, type_ in ctype.fields.items()
}
raise DataError(f"Unsupported data type returned: {ctype.__name__}")


Expand Down
1 change: 1 addition & 0 deletions src/firebolt/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
NUMBER,
ROWID,
STRING,
STRUCT,
Binary,
Date,
DateFromTicks,
Expand Down
26 changes: 26 additions & 0 deletions tests/integration/dbapi/async/V2/test_queries_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,3 +429,29 @@ async def test_select_geography(
select_geography_response,
"Invalid data returned by fetchall",
)


async def test_select_struct(
    connection: Connection,
    setup_struct_query: str,
    cleanup_struct_query: str,
    select_struct_query: str,
    select_struct_description: List[Column],
    select_struct_response: List[ColType],
):
    """Check description and data returned when selecting a struct value.

    Runs the setup script, executes the select, compares the cursor
    description and the fetched rows with the expected fixtures, and always
    runs the cleanup script afterwards.
    """
    with connection.cursor() as cursor:
        try:
            await cursor.execute(setup_struct_query)
            await cursor.execute(select_struct_query)

            description = cursor.description
            assert (
                description == select_struct_description
            ), "Invalid description value"

            rows = await cursor.fetchall()
            row_count = len(rows)
            assert row_count == 1, "Invalid data length"
            assert_deep_eq(
                rows,
                select_struct_response,
                "Invalid data returned by fetchall",
            )
        finally:
            await cursor.execute(cleanup_struct_query)
53 changes: 52 additions & 1 deletion tests/integration/dbapi/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pytest import fixture

from firebolt.async_db.cursor import Column
from firebolt.common._types import ColType
from firebolt.common._types import STRUCT, ColType
from firebolt.db import ARRAY, DECIMAL, Connection

LOGGER = getLogger(__name__)
Expand Down Expand Up @@ -209,3 +209,54 @@ def select_geography_description() -> List[Column]:
@fixture
def select_geography_response() -> List[ColType]:
return [["0101000020E6100000FEFFFFFFFFFFEF3F000000000000F03F"]]


@fixture
def setup_struct_query() -> str:
    """Return the multi-statement SQL script that prepares the struct test data.

    The script sets the session flags listed in it, drops and recreates the
    ``test_struct`` and ``test_struct_helper`` tables, inserts one row into the
    helper table and then inserts one row into ``test_struct`` selected from
    the helper table.
    """
    return """
SET advanced_mode=1;
SET enable_struct=1;
SET enable_create_table_v2=true;
SET enable_row_selection=true;
SET prevent_create_on_information_schema=true;
SET enable_create_table_with_struct_type=true;
DROP TABLE IF EXISTS test_struct;
DROP TABLE IF EXISTS test_struct_helper;
CREATE TABLE IF NOT EXISTS test_struct(id int not null, s struct(a array(int) not null, b datetime null) not null);
CREATE TABLE IF NOT EXISTS test_struct_helper(a array(int) not null, b datetime null);
INSERT INTO test_struct_helper(a, b) VALUES ([1, 2], '2019-07-31 01:01:01');
INSERT INTO test_struct(id, s) SELECT 1, test_struct_helper FROM test_struct_helper;
"""


@fixture
def cleanup_struct_query() -> str:
    """Return the SQL script that drops the ``test_struct`` and ``test_struct_helper`` tables."""
    return """
DROP TABLE IF EXISTS test_struct;
DROP TABLE IF EXISTS test_struct_helper;
"""


@fixture
def select_struct_query() -> str:
    """Return the query that selects ``test_struct`` from the ``test_struct`` table."""
    query = "SELECT test_struct FROM test_struct"
    return query


@fixture
def select_struct_description() -> List[Column]:
    """Return the cursor description expected for the struct select query.

    A single ``test_struct`` column whose type is a struct with an ``id``
    integer and a nested struct ``s`` holding an int array and a datetime.
    """
    nested_type = STRUCT({"a": ARRAY(int), "b": datetime})
    row_type = STRUCT({"id": int, "s": nested_type})
    column = Column("test_struct", row_type, None, None, None, None, None)
    return [column]


@fixture
def select_struct_response() -> List[ColType]:
    """Return the rows expected from the struct select query: one row, one struct value."""
    nested_value = {"a": [1, 2], "b": datetime(2019, 7, 31, 1, 1, 1)}
    struct_value = {"id": 1, "s": nested_value}
    return [[struct_value]]
26 changes: 26 additions & 0 deletions tests/integration/dbapi/sync/V2/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,3 +512,29 @@ def test_select_geography(
select_geography_response,
"Invalid data returned by fetchall",
)


def test_select_struct(
    connection: Connection,
    setup_struct_query: str,
    cleanup_struct_query: str,
    select_struct_query: str,
    select_struct_description: List[Column],
    select_struct_response: List[ColType],
):
    """Check description and data returned when selecting a struct value.

    Runs the setup script, executes the select, compares the cursor
    description and the fetched rows with the expected fixtures, and always
    runs the cleanup script afterwards.
    """
    with connection.cursor() as cursor:
        try:
            cursor.execute(setup_struct_query)
            cursor.execute(select_struct_query)

            description = cursor.description
            assert (
                description == select_struct_description
            ), "Invalid description value"

            rows = cursor.fetchall()
            row_count = len(rows)
            assert row_count == 1, "Invalid data length"
            assert_deep_eq(
                rows,
                select_struct_response,
                "Invalid data returned by fetchall",
            )
        finally:
            cursor.execute(cleanup_struct_query)
Loading

0 comments on commit c7526ea

Please sign in to comment.