Skip to content

Commit

Permalink
fix: handle temporal columns in presto partitions
Browse files Browse the repository at this point in the history
The where_latest_partition_date method incorrectly handled column types
as strings, but they're provided as SQLA types instead.

Deal with the DATE and TIMESTAMP cases, which were being incorrectly
rendered in the query as a result of the above, and causing table
preview queries to fail.
  • Loading branch information
giftig committed May 19, 2023
1 parent d583ca9 commit d67b11e
Show file tree
Hide file tree
Showing 4 changed files with 50 additions and 11 deletions.
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1276,7 +1276,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen
schema: Optional[str],
database: Database,
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
"""
Add a where clause to a query to reference only the most recent partition
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments
schema: Optional[str],
database: "Database",
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
Expand Down
18 changes: 10 additions & 8 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments
schema: Optional[str],
database: Database,
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
Expand All @@ -513,13 +513,15 @@ def where_latest_partition( # pylint: disable=too-many-arguments
}

for col_name, value in zip(col_names, values):
if col_name in column_type_by_name:
if column_type_by_name.get(col_name) == "TIMESTAMP":
query = query.where(Column(col_name, TimeStamp()) == value)
elif column_type_by_name.get(col_name) == "DATE":
query = query.where(Column(col_name, Date()) == value)
else:
query = query.where(Column(col_name) == value)
col_type = column_type_by_name.get(col_name)

if isinstance(col_type, types.DATE):
col_type = Date()
elif isinstance(col_type, types.TIMESTAMP):
col_type = TimeStamp()

query = query.where(Column(col_name, col_type) == value)

return query

@classmethod
Expand Down
39 changes: 38 additions & 1 deletion tests/unit_tests/db_engine_specs/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
from unittest import mock

import pytest
import pytz
from sqlalchemy import types
from pyhive.sqlalchemy_presto import PrestoDialect
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url

from superset.utils.core import GenericDataType
Expand Down Expand Up @@ -106,3 +108,38 @@ def test_get_schema_from_engine_params() -> None:
)
is None
)


@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
@pytest.mark.parametrize(
["column_type", "column_value", "expected_value"],
[
(types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
(types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
(types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
(types.INT(), 1234, "1234"),
],
)
def test_where_latest_partition(
mock_latest_partition, column_type, column_value: Any, expected_value: str
) -> None:
"""
Test the ``where_latest_partition`` method
"""
from superset.db_engine_specs.presto import PrestoEngineSpec as spec

mock_latest_partition.return_value = (["partition_key"], [column_value])

query = sql.select(text("* FROM table"))
columns = [{"name": "partition_key", "type": column_type}]

expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
result = spec.where_latest_partition(
"table", mock.MagicMock(), mock.MagicMock(), query, columns
)
assert result is not None
actual = result.compile(
dialect=PrestoDialect(), compile_kwargs={"literal_binds": True}
)

assert str(actual) == expected

0 comments on commit d67b11e

Please sign in to comment.