Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bump sqlglot from 21.0.0 to 21.1.2 #137

Merged
merged 9 commits into from
Feb 22, 2024
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ classifiers = [
]
dependencies = [
"databricks-sdk>=0.18,<0.21",
"sqlglot==21.0.0",
"sqlglot==21.1.2",
"databricks-labs-blueprint~=0.1.0"
]

Expand Down
2 changes: 1 addition & 1 deletion src/databricks/labs/remorph/snow/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def _format_create_sql(self, expression: exp.Create) -> str:
if expression.args.get(arg_to_delete):
del expression.args[arg_to_delete]

return hive._create_sql(self, expression)
return self.create_sql(expression)


def _curr_time():
Expand Down
14 changes: 8 additions & 6 deletions src/databricks/labs/remorph/snow/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import typing as t
from typing import ClassVar

from sqlglot import exp, parser
from sqlglot.dialects.dialect import parse_date_delta
from sqlglot import exp
from sqlglot.dialects.dialect import build_date_delta as parse_date_delta
from sqlglot.dialects.snowflake import Snowflake
from sqlglot.dialects.snowflake import _parse_to_timestamp as parse_to_timestamp
from sqlglot.dialects.snowflake import _build_to_timestamp as parse_to_timestamp
from sqlglot.errors import ParseError
from sqlglot.helper import seq_get
from sqlglot.parser import build_var_map as parse_var_map
from sqlglot.tokens import Token, TokenType
from sqlglot.trie import new_trie

Expand Down Expand Up @@ -141,7 +142,7 @@ def _parse_monthname(args: list) -> local_expression.DateFormat:


def _parse_object_construct(args: list) -> exp.StarMap | exp.Struct:
expression = parser.parse_var_map(args)
expression = parse_var_map(args)

if isinstance(expression, exp.StarMap):
return exp.Struct(expressions=[expression.this])
Expand Down Expand Up @@ -422,7 +423,8 @@ def _get_table_alias(self):
"""
if self_copy._index + 2 < len(self_copy._tokens):
self_copy._advance(2)
table_alias = self_copy._curr.text
if self_copy._curr.text != ")":
table_alias = self_copy._curr.text
"""
* if the table is of format :: `<DB>.<TABLE>` <TABLE_ALIAS>, advance to two more tokens
- Handles (`SELECT .... FROM dwh.vw_replacement_customer d` => returns d)
Expand Down Expand Up @@ -470,7 +472,7 @@ def _json_column_op(self, this, path):
return self.expression(local_expression.Bracket, this=this, expressions=[path])
elif isinstance(this, local_expression.Bracket) and (is_name_value or is_table_alias):
return self.expression(local_expression.Bracket, this=this, expressions=[path])
elif isinstance(path, exp.Literal) and (path or is_path_value):
elif (isinstance(path, exp.Column)) and (path or is_path_value):
return self.expression(local_expression.Bracket, this=this, expressions=[path])
else:
return self.expression(exp.Bracket, this=this, expressions=[path])
27 changes: 25 additions & 2 deletions tests/unit/snow/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def test_lateral_struct(dialect_context):
)
validate_source_transpile(
databricks_sql="""SELECT tt.id, FROM_JSON(tt.details, {TT.DETAILS_SCHEMA}) FROM prod.public.table AS tt
LATERAL VIEW EXPLODE(FROM_JSON(FROM_JSON(tt.resp)['items'])) AS lit
LATERAL VIEW EXPLODE(FROM_JSON(lit.value['details'])) AS ltd""",
LATERAL VIEW EXPLODE(FROM_JSON(FROM_JSON(tt.resp)[items])) AS lit
LATERAL VIEW EXPLODE(FROM_JSON(lit.value["details"])) AS ltd""",
source={
"snowflake": """
SELECT
Expand All @@ -150,6 +150,12 @@ def test_lateral_struct(dialect_context):
, LATERAL FLATTEN (input=> parse_json(lit.value:"details")) AS ltd;""",
},
)
validate_source_transpile(
"SELECT level_1_key.level_2_key['1'] FROM demo1",
source={
"snowflake": "SELECT level_1_key:level_2_key:'1' FROM demo1;",
},
)


def test_datediff(dialect_context):
Expand Down Expand Up @@ -609,6 +615,14 @@ def test_monthname(dialect_context):
"snowflake": """SELECT MONTHNAME('2015-03-04') AS MON;""",
},
)
with pytest.raises(ParseError):
# Snowflake expects only 1 argument for MonthName function
validate_source_transpile(
"""SELECT DATE_FORMAT('2015-04-03 10:00', 'MMM') AS MONTH""",
source={
"snowflake": """SELECT MONTHNAME('2015-04-03 10:00', 'MMM') AS MONTH;""",
},
)

with pytest.raises(ParseError):
validate_source_transpile(
Expand Down Expand Up @@ -3022,6 +3036,15 @@ def test_to_number(dialect_context):
TO_NUMBER(sm.col2, '$99.00', 15, 5) AS col2 FROM sales_reports sm""",
},
)
with pytest.raises(UnsupportedError):
# Test case to validate `TO_NUMBER` parsing with precision and scale from table columns.
# Format is a mandatory argument in Databricks
validate_source_transpile(
"""SELECT CAST(TO_NUMBER(sm.col1) AS DECIMAL(15, 5)) AS col1 FROM sales_reports AS sm""",
source={
"snowflake": """SELECT TO_NUMERIC(sm.col1, 15, 5) AS col1 FROM sales_reports sm""",
},
)

with pytest.raises(UnsupportedError):
# Test case to validate `TO_DECIMAL` parsing without format
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/snow/test_sql_transpiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,19 @@ def test_transpile_exception():
assert result == ""
assert error_list[1].file_name == "file.sql"
assert "Error Parsing args" in error_list[1].exception.args[0]


def test_tokenizer_exception():
    """Transpiling un-tokenizable input returns "" and records an error.

    The error list is seeded with one placeholder ``ParseError`` so that the
    entry added during ``transpile()`` is expected at index 1.
    """
    # NOTE(review): the lone surrogate "\ud83d" together with the unbalanced
    # quote is presumably what makes tokenization fail -- confirm.
    bad_sql = "1SELECT ~v\ud83d' "
    errors = [ParseError("", "")]

    output = SQLTranspiler("SNOWFLAKE", bad_sql, "file.sql", errors).transpile()

    assert output == ""
    recorded = errors[1]
    assert recorded.file_name == "file.sql"
    assert "Error tokenizing" in recorded.exception.args[0]


def test_procedure_conversion():
    """A Snowflake ``CREATE OR REPLACE PROCEDURE`` statement is transpiled.

    Only the first element of the ``transpile()`` result is compared against
    the expected output string.
    """
    source_sql = "CREATE OR REPLACE PROCEDURE my_procedure() AS BEGIN SELECT * FROM my_table; END;"
    expected = "CREATE PROCEDURE my_procedure(\n \n) AS BEGIN\nSELECT\n *\nFROM my_table"

    # An empty error list is passed in; this test does not inspect it.
    statements = SQLTranspiler("SNOWFLAKE", source_sql, "file.sql", []).transpile()

    assert statements[0] == expected
Loading