Skip to content

Commit

Permalink
Fix!: use maybe_parse in exp.to_table, fix exp.Table expression parser (
Browse files Browse the repository at this point in the history
tobymao#1684)

* Fix: allow >3 table parts in exp.to_table

* Fix mypy types

* Keep parity between to_table and _parse_table by using maybe_parse

* Refactor bigquery _parse_table_parts

* Fix exp.Table expression parser

* Fixup
  • Loading branch information
georgesittas authored and adrianisk committed Jun 21, 2023
1 parent e84cc53 commit a68921e
Show file tree
Hide file tree
Showing 7 changed files with 43 additions and 16 deletions.
15 changes: 12 additions & 3 deletions sqlglot/dialects/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
timestrtotime_sql,
ts_or_ds_to_date_sql,
)
from sqlglot.helper import seq_get
from sqlglot.helper import seq_get, split_num_words
from sqlglot.tokens import TokenType

E = t.TypeVar("E", bound=exp.Expression)
Expand Down Expand Up @@ -230,10 +230,19 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:

return this

def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
if isinstance(table.this, exp.Identifier) and "." in table.name:
table = exp.to_table(table.name, dialect="bigquery")
catalog, db, this, *rest = (
t.cast(t.Optional[exp.Expression], exp.to_identifier(x))
for x in split_num_words(table.name, ".", 3)
)

if rest and this:
this = exp.Dot.build(t.cast(t.List[exp.Expression], [this, *rest]))

table = exp.Table(this=this, db=db, catalog=catalog)

return table

class Generator(generator.Generator):
Expand Down
2 changes: 1 addition & 1 deletion sqlglot/dialects/tsql.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ def _parse_system_time(self) -> t.Optional[exp.Expression]:

return system_time

def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
table = super()._parse_table_parts(schema=schema)
table.set("system_time", self._parse_system_time())
return table
Expand Down
17 changes: 12 additions & 5 deletions sqlglot/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
ensure_collection,
ensure_list,
seq_get,
split_num_words,
subclasses,
)
from sqlglot.tokens import Token
Expand Down Expand Up @@ -2196,7 +2195,7 @@ def catalog(self) -> str:

@property
def parts(self) -> t.List[Identifier]:
"""Return the parts of a column in order catalog, db, table."""
"""Return the parts of a table in order catalog, db, table."""
return [
t.cast(Identifier, self.args[part])
for part in ("catalog", "db", "this")
Expand Down Expand Up @@ -5030,13 +5029,17 @@ def to_table(sql_path: None, **kwargs) -> None:
...


def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
def to_table(
sql_path: t.Optional[str | Table], dialect: DialectType = None, **kwargs
) -> t.Optional[Table]:
"""
Create a table expression from a `[catalog].[schema].[table]` sql path. Catalog and schema are optional.
If a table is passed in then that table is returned.
Args:
sql_path: a `[catalog].[schema].[table]` string.
dialect: the source dialect according to which the table name will be parsed.
kwargs: the kwargs to instantiate the resulting `Table` expression with.
Returns:
A table expression.
Expand All @@ -5046,8 +5049,12 @@ def to_table(sql_path: t.Optional[str | Table], **kwargs) -> t.Optional[Table]:
if not isinstance(sql_path, str):
raise ValueError(f"Invalid type provided for a table: {type(sql_path)}")

catalog, db, table_name = (to_identifier(x) for x in split_num_words(sql_path, ".", 3))
return Table(this=table_name, db=db, catalog=catalog, **kwargs)
table = maybe_parse(sql_path, into=Table, dialect=dialect)
if table:
for k, v in kwargs.items():
table.set(k, v)

return table


def to_column(sql_path: str | Column, **kwargs) -> Column:
Expand Down
6 changes: 3 additions & 3 deletions sqlglot/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ class Parser(metaclass=_Parser):
exp.Limit: lambda self: self._parse_limit(),
exp.Offset: lambda self: self._parse_offset(),
exp.TableAlias: lambda self: self._parse_table_alias(),
exp.Table: lambda self: self._parse_table(),
exp.Table: lambda self: self._parse_table_parts(),
exp.Condition: lambda self: self._parse_conjunction(),
exp.Expression: lambda self: self._parse_statement(),
exp.Properties: lambda self: self._parse_properties(),
Expand Down Expand Up @@ -2227,7 +2227,7 @@ def _parse_table_part(self, schema: bool = False) -> t.Optional[exp.Expression]:
or self._parse_placeholder()
)

def _parse_table_parts(self, schema: bool = False) -> exp.Expression:
def _parse_table_parts(self, schema: bool = False) -> exp.Table:
catalog = None
db = None
table = self._parse_table_part(schema=schema)
Expand Down Expand Up @@ -2271,7 +2271,7 @@ def _parse_table(
subquery.set("pivots", self._parse_pivots())
return subquery

this = self._parse_table_parts(schema=schema)
this: exp.Expression = self._parse_table_parts(schema=schema)

if schema:
return self._parse_schema(this=this)
Expand Down
6 changes: 6 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,12 @@ def test_bigquery(self):
"CREATE TEMP TABLE foo AS SELECT 1",
write={"bigquery": "CREATE TEMPORARY TABLE foo AS SELECT 1"},
)
self.validate_all(
"SELECT * FROM `SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW`",
write={
"bigquery": "SELECT * FROM SOME_PROJECT_ID.SOME_DATASET_ID.INFORMATION_SCHEMA.SOME_VIEW",
},
)
self.validate_all(
"SELECT * FROM `my-project.my-dataset.my-table`",
write={"bigquery": "SELECT * FROM `my-project`.`my-dataset`.`my-table`"},
Expand Down
4 changes: 0 additions & 4 deletions tests/test_expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,10 +726,6 @@ def test_to_table(self):
self.assertEqual(catalog_db_and_table.args.get("catalog"), exp.to_identifier("catalog"))
with self.assertRaises(ValueError):
exp.to_table(1)
empty_string = exp.to_table("")
self.assertEqual(empty_string.name, "")
self.assertIsNone(table_only.args.get("db"))
self.assertIsNone(table_only.args.get("catalog"))

def test_to_column(self):
column_only = exp.to_column("column_name")
Expand Down
9 changes: 9 additions & 0 deletions tests/test_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ def test_parse_into(self):
self.assertIsInstance(parse_one("left join foo", into=exp.Join), exp.Join)
self.assertIsInstance(parse_one("int", into=exp.DataType), exp.DataType)
self.assertIsInstance(parse_one("array<int>", into=exp.DataType), exp.DataType)
self.assertIsInstance(parse_one("foo", into=exp.Table), exp.Table)

with self.assertRaises(ParseError) as ctx:
parse_one("SELECT * FROM tbl", into=exp.Table)

self.assertEqual(
str(ctx.exception),
"Failed to parse into <class 'sqlglot.expressions.Table'>",
)

def test_parse_into_error(self):
expected_message = "Failed to parse into [<class 'sqlglot.expressions.From'>]"
Expand Down

0 comments on commit a68921e

Please sign in to comment.