Skip to content

Commit

Permalink
Feat(optimizer): improve struct type annotation support for EQ-delimited kv pairs (#2863)
Browse files Browse the repository at this point in the history

* annotate-struct-pack

* annotate-struct-pack

* Update sqlglot/optimizer/annotate_types.py

Co-authored-by: Jo <[email protected]>

* Update sqlglot/optimizer/annotate_types.py

Co-authored-by: Jo <[email protected]>

* Apply suggestions from code review

Co-authored-by: Jo <[email protected]>

* Update sqlglot/optimizer/annotate_types.py

---------

Co-authored-by: Jo <[email protected]>
  • Loading branch information
fool1280 and georgesittas authored Jan 19, 2024
1 parent ad14f4e commit d5a08b8
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 17 deletions.
27 changes: 19 additions & 8 deletions sqlglot/optimizer/annotate_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,6 +479,20 @@ def _annotate_with_type(self, expression: E, target_type: exp.DataType.Type) ->
self._set_type(expression, target_type)
return self._annotate_args(expression)

@t.no_type_check
def _annotate_struct_value(
    self, expression: exp.Expression
) -> t.Optional[exp.DataType] | exp.ColumnDef:
    """Build the type entry for a single struct argument.

    Three shapes are recognised, checked in this order:

    * the argument carries an ``alias`` arg: the result is a ``ColumnDef``
      named after a copy of the alias and typed with the argument's own type;
    * the argument has an ``expression`` operand (``key = value`` or
      ``key := value``): the result is a ``ColumnDef`` named after a copy of
      ``expression.this`` and typed with the operand's type;
    * anything else: the argument's own type is returned unchanged, which
      may be ``None``.

    Args:
        expression: One argument of the struct constructor.

    Returns:
        A ``ColumnDef`` for named fields, otherwise ``expression.type``.
    """
    alias = expression.args.get("alias")
    value = expression.expression

    if alias:
        field_name, field_type = alias, expression.type
    elif value:
        # Case: key = value or key := value
        field_name, field_type = expression.this, value.type
    else:
        # Unnamed field: only its type is recorded.
        return expression.type

    # The name node is copied so the ColumnDef does not share it with the source.
    return exp.ColumnDef(this=field_name.copy(), kind=field_type)

@t.no_type_check
def _annotate_by_args(
self,
Expand Down Expand Up @@ -516,16 +530,13 @@ def _annotate_by_args(
)

if struct:
expressions = [
expr.type
if not expr.args.get("alias")
else exp.ColumnDef(this=expr.args["alias"].copy(), kind=expr.type)
for expr in expressions
]

self._set_type(
expression,
exp.DataType(this=exp.DataType.Type.STRUCT, expressions=expressions, nested=True),
exp.DataType(
this=exp.DataType.Type.STRUCT,
expressions=[self._annotate_struct_value(expr) for expr in expressions],
nested=True,
),
)

return expression
Expand Down
26 changes: 17 additions & 9 deletions tests/test_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,18 +550,26 @@ def test_scope_warning(self, logger):

def test_struct_type_annotation(self):
tests = {
"SELECT STRUCT(1 AS col)": "STRUCT<col INT>",
"SELECT STRUCT(1 AS col, 2.5 AS row)": "STRUCT<col INT, row DOUBLE>",
"SELECT STRUCT(1)": "STRUCT<INT>",
"SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)": "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>",
"SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')": "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>",
"SELECT STRUCT(1, 2.5, 'bar')": "STRUCT<INT, DOUBLE, VARCHAR>",
'SELECT STRUCT(1 AS "CaseSensitive")': 'STRUCT<"CaseSensitive" INT>',
("SELECT STRUCT(1 AS col)", "spark"): "STRUCT<col INT>",
("SELECT STRUCT(1 AS col, 2.5 AS row)", "spark"): "STRUCT<col INT, row DOUBLE>",
("SELECT STRUCT(1)", "bigquery"): "STRUCT<INT>",
(
"SELECT STRUCT(1 AS col, 2.5 AS row, struct(3.5 AS inner_col, 4 AS inner_row) AS nested_struct)",
"spark",
): "STRUCT<col INT, row DOUBLE, nested_struct STRUCT<inner_col DOUBLE, inner_row INT>>",
(
"SELECT STRUCT(1 AS col, 2.5, ARRAY[1, 2, 3] AS nested_array, 'foo')",
"bigquery",
): "STRUCT<col INT, DOUBLE, nested_array ARRAY<INT>, VARCHAR>",
("SELECT STRUCT(1, 2.5, 'bar')", "spark"): "STRUCT<INT, DOUBLE, VARCHAR>",
('SELECT STRUCT(1 AS "CaseSensitive")', "spark"): 'STRUCT<"CaseSensitive" INT>',
("SELECT STRUCT_PACK(a := 1, b := 2.5)", "duckdb"): "STRUCT<a INT, b DOUBLE>",
("SELECT ROW(1, 2.5, 'foo')", "presto"): "STRUCT<INT, DOUBLE, VARCHAR>",
}

for sql, target_type in tests.items():
for (sql, dialect), target_type in tests.items():
with self.subTest(sql):
expression = annotate_types(parse_one(sql))
expression = annotate_types(parse_one(sql, read=dialect))
assert expression.expressions[0].is_type(target_type)

def test_literal_type_annotation(self):
Expand Down

0 comments on commit d5a08b8

Please sign in to comment.