Skip to content

Commit

Permalink
fix(spark)!: Make TIMESTAMP map to Type.TIMESTAMPTZ (#4451)
Browse files (browse the repository at this point in the history)
* fix(spark): Make TIMESTAMP map to Type.TIMESTAMPTZ

* Move token def to Spark2
  • Loading branch information
VaggelisD authored Nov 28, 2024
1 parent ca5023d commit 07fa69d
Show file tree
Hide file tree
Showing 6 changed files with 36 additions and 14 deletions.
12 changes: 9 additions & 3 deletions sqlglot/dialects/spark2.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from sqlglot.dialects.hive import Hive
from sqlglot.helper import seq_get, ensure_list
from sqlglot.tokens import TokenType
from sqlglot.transforms import (
preprocess,
remove_unique_constraints,
Expand Down Expand Up @@ -159,6 +160,14 @@ class Spark2(Hive):
),
}

# Tokenizer for the Spark2 dialect: subclasses Hive's tokenizer and overrides
# the hex-string delimiters and the keyword table defined below.
class Tokenizer(Hive.Tokenizer):
# Delimiter pairs for hex string literals, covering both the upper-case
# X'...' and lower-case x'...' prefixes (presumably (opening, closing) pairs).
HEX_STRINGS = [("X'", "'"), ("x'", "'")]

# Start from every keyword Hive's tokenizer defines, then override a single
# entry: the TIMESTAMP keyword is tokenized as TokenType.TIMESTAMPTZ.
KEYWORDS = {
**Hive.Tokenizer.KEYWORDS,
"TIMESTAMP": TokenType.TIMESTAMPTZ,
}

class Parser(Hive.Parser):
TRIM_PATTERN_FIRST = True

Expand Down Expand Up @@ -337,6 +346,3 @@ def columndef_sql(self, expression: exp.ColumnDef, sep: str = " ") -> str:
else sep
),
)

class Tokenizer(Hive.Tokenizer):
HEX_STRINGS = [("X'", "'"), ("x'", "'")]
3 changes: 2 additions & 1 deletion tests/dialects/test_databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ def test_databricks(self):
"CREATE TABLE IF NOT EXISTS db.table (a TIMESTAMP, b BOOLEAN GENERATED ALWAYS AS (NOT a IS NULL)) USING DELTA"
)
self.validate_identity(
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
"SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
)
self.validate_identity(
"SELECT * FROM sales UNPIVOT INCLUDE NULLS (sales FOR quarter IN (q1 AS `Jan-Mar`))"
Expand Down
8 changes: 4 additions & 4 deletions tests/dialects/test_hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -761,13 +761,13 @@ def test_hive(self):
},
)
self.validate_all(
"SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH') AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
"SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH')",
read={
"hive": "SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH') AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
"presto": "SELECT DATE_TRUNC('MONTH', CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
"hive": "SELECT TRUNC(CAST(ds AS TIMESTAMP), 'MONTH')",
"presto": "SELECT DATE_TRUNC('MONTH', CAST(ds AS TIMESTAMP))",
},
write={
"presto": "SELECT DATE_TRUNC('MONTH', TRY_CAST(ds AS TIMESTAMP)) AS mm FROM tbl WHERE ds BETWEEN '2023-10-01' AND '2024-02-29'",
"presto": "SELECT DATE_TRUNC('MONTH', TRY_CAST(ds AS TIMESTAMP))",
},
)
self.validate_all(
Expand Down
13 changes: 9 additions & 4 deletions tests/dialects/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,22 +1273,27 @@ def test_safe_div(self):
)

def test_timestamp_trunc(self):
for dialect in ("postgres", "snowflake", "duckdb", "spark", "databricks"):
hive_dialects = ("spark", "databricks")
for dialect in ("postgres", "snowflake", "duckdb", *hive_dialects):
for unit in (
"MILLISECOND",
"SECOND",
"DAY",
"MONTH",
"YEAR",
):
with self.subTest(f"MySQL -> {dialect} Timestamp Trunc with unit {unit}: "):
cast = (
"TIMESTAMP('2001-02-16 20:38:40')"
if dialect in hive_dialects
else "CAST('2001-02-16 20:38:40' AS DATETIME)"
)
self.validate_all(
f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', {cast})) {unit})",
read={
dialect: f"DATE_TRUNC({unit}, TIMESTAMP '2001-02-16 20:38:40')",
},
write={
"mysql": f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', CAST('2001-02-16 20:38:40' AS DATETIME))) {unit})",
"mysql": f"DATE_ADD('0000-01-01 00:00:00', INTERVAL (TIMESTAMPDIFF({unit}, '0000-01-01 00:00:00', {cast})) {unit})",
},
)

Expand Down
4 changes: 2 additions & 2 deletions tests/dialects/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@ def test_time(self):
},
)
self.validate_all(
"((DAY_OF_WEEK(CAST(TRY_CAST('2012-08-08 01:00:00' AS TIMESTAMP) AS DATE)) % 7) + 1)",
"((DAY_OF_WEEK(CAST(CAST(TRY_CAST('2012-08-08 01:00:00' AS TIMESTAMP WITH TIME ZONE) AS TIMESTAMP) AS DATE)) % 7) + 1)",
read={
"spark": "DAYOFWEEK(CAST('2012-08-08 01:00:00' AS TIMESTAMP))",
},
Expand Down Expand Up @@ -406,7 +406,7 @@ def test_time(self):
},
)
self.validate_all(
"SELECT AT_TIMEZONE(CAST('2012-10-31 00:00' AS TIMESTAMP), 'America/Sao_Paulo')",
"SELECT AT_TIMEZONE(CAST(CAST('2012-10-31 00:00' AS TIMESTAMP WITH TIME ZONE) AS TIMESTAMP), 'America/Sao_Paulo')",
read={
"spark": "SELECT FROM_UTC_TIMESTAMP(TIMESTAMP '2012-10-31 00:00', 'America/Sao_Paulo')",
},
Expand Down
10 changes: 10 additions & 0 deletions tests/dialects/test_spark.py
Original file line number Diff line number Diff line change
Expand Up @@ -765,6 +765,16 @@ def test_spark(self):
},
)

self.validate_all(
"SELECT CAST(col AS TIMESTAMP)",
write={
"spark2": "SELECT CAST(col AS TIMESTAMP)",
"spark": "SELECT CAST(col AS TIMESTAMP)",
"databricks": "SELECT TRY_CAST(col AS TIMESTAMP)",
"duckdb": "SELECT TRY_CAST(col AS TIMESTAMPTZ)",
},
)

def test_bool_or(self):
self.validate_all(
"SELECT a, LOGICAL_OR(b) FROM table GROUP BY a",
Expand Down

0 comments on commit 07fa69d

Please sign in to comment.