Skip to content

Commit

Permalink
fix: handle temporal columns in presto partitions (#24054)
Browse files Browse the repository at this point in the history
  • Loading branch information
giftig authored and eschutho committed Jun 13, 2023
1 parent cfb4d27 commit 5f21e73
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 11 deletions.
2 changes: 1 addition & 1 deletion superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1168,7 +1168,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments,unused-argumen
schema: Optional[str],
database: Database,
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
"""
Add a where clause to a query to reference only the most recent partition
Expand Down
2 changes: 1 addition & 1 deletion superset/db_engine_specs/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments
schema: Optional[str],
database: "Database",
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
Expand Down
18 changes: 10 additions & 8 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@ def where_latest_partition( # pylint: disable=too-many-arguments
schema: Optional[str],
database: Database,
query: Select,
columns: Optional[List[Dict[str, str]]] = None,
columns: Optional[List[Dict[str, Any]]] = None,
) -> Optional[Select]:
try:
col_names, values = cls.latest_partition(
Expand All @@ -480,13 +480,15 @@ def where_latest_partition( # pylint: disable=too-many-arguments
}

for col_name, value in zip(col_names, values):
if col_name in column_type_by_name:
if column_type_by_name.get(col_name) == "TIMESTAMP":
query = query.where(Column(col_name, TimeStamp()) == value)
elif column_type_by_name.get(col_name) == "DATE":
query = query.where(Column(col_name, Date()) == value)
else:
query = query.where(Column(col_name) == value)
col_type = column_type_by_name.get(col_name)

if isinstance(col_type, types.DATE):
col_type = Date()
elif isinstance(col_type, types.TIMESTAMP):
col_type = TimeStamp()

query = query.where(Column(col_name, col_type) == value)

return query

@classmethod
Expand Down
43 changes: 42 additions & 1 deletion tests/unit_tests/db_engine_specs/test_presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,13 @@
# under the License.
from datetime import datetime
from typing import Any, Dict, Optional, Type
from unittest import mock

import pytest
import pytz
from sqlalchemy import types
from pyhive.sqlalchemy_presto import PrestoDialect
from sqlalchemy import sql, text, types
from sqlalchemy.engine.url import make_url

from superset.utils.core import GenericDataType
from tests.unit_tests.db_engine_specs.utils import (
Expand Down Expand Up @@ -82,3 +85,41 @@ def test_get_column_spec(
from superset.db_engine_specs.presto import PrestoEngineSpec as spec

assert_column_spec(spec, native_type, sqla_type, attrs, generic_type, is_dttm)


@mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition")
@pytest.mark.parametrize(
    ["column_type", "column_value", "expected_value"],
    [
        (types.DATE(), "2023-05-01", "DATE '2023-05-01'"),
        (types.TIMESTAMP(), "2023-05-01", "TIMESTAMP '2023-05-01'"),
        (types.VARCHAR(), "2023-05-01", "'2023-05-01'"),
        (types.INT(), 1234, "1234"),
    ],
)
def test_where_latest_partition(
    mock_latest_partition: Any,
    column_type: Any,
    column_value: Any,
    expected_value: str,
) -> None:
    """
    Test the ``where_latest_partition`` method

    ``PrestoEngineSpec.latest_partition`` is patched to report a single
    partition column named ``partition_key`` whose latest value is
    ``column_value``. The resulting query is compiled with the Presto dialect
    and literal binds, and the rendered SQL is compared against a ``WHERE``
    clause that uses ``expected_value`` as the literal.

    :param mock_latest_partition: mock replacing ``latest_partition``
    :param column_type: SQLAlchemy type instance reported for the column
    :param column_value: partition value returned by the mock; a string for
        the date, timestamp and varchar cases and an integer for the INT case
    :param expected_value: SQL literal expected on the right-hand side
    """
    # Imported inside the test, mirroring the other tests in this module.
    from superset.db_engine_specs.presto import PrestoEngineSpec as spec

    mock_latest_partition.return_value = (["partition_key"], [column_value])

    query = sql.select(text("* FROM table"))
    columns = [{"name": "partition_key", "type": column_type}]

    # The trailing space before the newline is part of the expected output.
    expected = f"""SELECT * FROM table \nWHERE "partition_key" = {expected_value}"""
    result = spec.where_latest_partition(
        "table", mock.MagicMock(), mock.MagicMock(), query, columns
    )
    assert result is not None

    # literal_binds inlines the bound value so the rendered SQL shows the
    # literal form under test rather than a bind placeholder.
    actual = result.compile(
        dialect=PrestoDialect(), compile_kwargs={"literal_binds": True}
    )

    assert str(actual) == expected

0 comments on commit 5f21e73

Please sign in to comment.