Skip to content

Commit

Permalink
fix(sqla): replace custom dttm type with literal_column (#19917)
Browse files Browse the repository at this point in the history
(cherry picked from commit 99f1f9e)
  • Loading branch information
villebro authored and michael-s-molina committed May 26, 2022
1 parent 2bd89d1 commit 9ca53b8
Show file tree
Hide file tree
Showing 14 changed files with 63 additions and 137 deletions.
31 changes: 23 additions & 8 deletions superset/connectors/base/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import json
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Hashable, List, Optional, Set, Type, Union
from typing import Any, Dict, Hashable, List, Optional, Set, Type, TYPE_CHECKING, Union

from flask_appbuilder.security.sqla.models import User
from sqlalchemy import and_, Boolean, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import foreign, Query, relationship, RelationshipProperty, Session
from sqlalchemy.sql import literal_column

from superset import is_feature_enabled, security_manager
from superset.constants import EMPTY_STRING, NULL_STRING
Expand All @@ -33,6 +36,9 @@
from superset.utils import core as utils
from superset.utils.core import GenericDataType

if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec

METRIC_FORM_DATA_PARAMS = [
"metric",
"metric_2",
Expand Down Expand Up @@ -387,33 +393,42 @@ def data_for_slices( # pylint: disable=too-many-locals
return data

@staticmethod
def filter_values_handler(
def filter_values_handler( # pylint: disable=too-many-arguments
values: Optional[FilterValues],
target_column_type: utils.GenericDataType,
target_generic_type: GenericDataType,
target_native_type: Optional[str] = None,
is_list_target: bool = False,
db_engine_spec: Optional[Type[BaseEngineSpec]] = None,
db_extra: Optional[Dict[str, Any]] = None,
) -> Optional[FilterValues]:
if values is None:
return None

def handle_single_value(value: Optional[FilterValue]) -> Optional[FilterValue]:
# backward compatibility with previous <select> components
if (
isinstance(value, (float, int))
and target_column_type == utils.GenericDataType.TEMPORAL
and target_generic_type == utils.GenericDataType.TEMPORAL
and target_native_type is not None
and db_engine_spec is not None
):
return datetime.utcfromtimestamp(value / 1000)
value = db_engine_spec.convert_dttm(
target_type=target_native_type,
dttm=datetime.utcfromtimestamp(value / 1000),
db_extra=db_extra,
)
value = literal_column(value)
if isinstance(value, str):
value = value.strip("\t\n")

if target_column_type == utils.GenericDataType.NUMERIC:
if target_generic_type == utils.GenericDataType.NUMERIC:
# For backwards compatibility and edge cases
# where a column data type might have changed
return utils.cast_to_num(value)
if value == NULL_STRING:
return None
if value == EMPTY_STRING:
return ""
if target_column_type == utils.GenericDataType.BOOLEAN:
if target_generic_type == utils.GenericDataType.BOOLEAN:
return utils.cast_to_boolean(value)
return value

Expand Down
2 changes: 1 addition & 1 deletion superset/connectors/druid/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1533,7 +1533,7 @@ def get_filters(
eq = cls.filter_values_handler(
eq,
is_list_target=is_list_target,
target_column_type=utils.GenericDataType.NUMERIC
target_generic_type=utils.GenericDataType.NUMERIC
if is_numeric_col
else utils.GenericDataType.STRING,
)
Expand Down
20 changes: 13 additions & 7 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,33 +1365,39 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
)
elif col_obj:
sqla_col = col_obj.get_sqla_col()
col_type = col_obj.type if col_obj else None
col_spec = db_engine_spec.get_column_spec(
col_obj.type if col_obj else None
native_type=col_type,
db_extra=self.database.get_extra(),
)
is_list_target = op in (
utils.FilterOperator.IN.value,
utils.FilterOperator.NOT_IN.value,
)
if col_spec:
target_type = col_spec.generic_type
target_generic_type = col_spec.generic_type
else:
target_type = GenericDataType.STRING
target_generic_type = GenericDataType.STRING
eq = self.filter_values_handler(
values=val,
target_column_type=target_type,
target_generic_type=target_generic_type,
target_native_type=col_type,
is_list_target=is_list_target,
db_engine_spec=db_engine_spec,
db_extra=self.database.get_extra(),
)
if is_list_target:
assert isinstance(eq, (tuple, list))
if len(eq) == 0:
raise QueryObjectValidationError(
_("Filter value list cannot be empty")
)
if None in eq:
eq = [x for x in eq if x is not None]
if len(eq) > len(
eq_without_none := [x for x in eq if x is not None]
):
is_null_cond = sqla_col.is_(None)
if eq:
cond = or_(is_null_cond, sqla_col.in_(eq))
cond = or_(is_null_cond, sqla_col.in_(eq_without_none))
else:
cond = is_null_cond
else:
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/athena.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ def convert_dttm(
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
return f"from_iso8601_date('{dttm.date().isoformat()}')"
return f"DATE '{dttm.date().isoformat()}'"
if tt == utils.TemporalType.TIMESTAMP:
datetime_formatted = dttm.isoformat(timespec="microseconds")
return f"""from_iso8601_timestamp('{datetime_formatted}')"""
return f"""TIMESTAMP '{datetime_formatted}'"""
return None

@classmethod
Expand Down
9 changes: 1 addition & 8 deletions superset/db_engine_specs/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@
from superset import security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.models.sql_types.base import literal_dttm_type_factory
from superset.sql_parse import ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
from superset.utils import core as utils
Expand Down Expand Up @@ -268,7 +267,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
),
(
re.compile(r"^date", re.IGNORECASE),
types.DateTime(),
types.Date(),
GenericDataType.TEMPORAL,
),
(
Expand Down Expand Up @@ -1476,12 +1475,6 @@ def get_column_spec( # pylint: disable=unused-argument
)
if col_types:
column_type, generic_type = col_types
# wrap temporal types in custom type that supports literal binding
# using datetimes
if generic_type == GenericDataType.TEMPORAL:
column_type = literal_dttm_type_factory(
column_type, cls, native_type or "", db_extra=db_extra or {}
)
is_dttm = generic_type == GenericDataType.TEMPORAL
return ColumnSpec(
sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
Expand Down
6 changes: 3 additions & 3 deletions superset/db_engine_specs/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def _show_columns(
),
(
re.compile(r"^date.*", re.IGNORECASE),
types.DATETIME(),
types.DATE(),
GenericDataType.TEMPORAL,
),
(
Expand Down Expand Up @@ -753,12 +753,12 @@ def convert_dttm(
"""
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
return f"""from_iso8601_date('{dttm.date().isoformat()}')"""
return f"""DATE '{dttm.date().isoformat()}'"""
if tt in (
utils.TemporalType.TIMESTAMP,
utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE,
):
return f"""from_iso8601_timestamp('{dttm.isoformat(timespec="microseconds")}')""" # pylint: disable=line-too-long,useless-suppression
return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds")}'"""
return None

@classmethod
Expand Down
4 changes: 2 additions & 2 deletions superset/db_engine_specs/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,12 +70,12 @@ def convert_dttm(
"""
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
return f"from_iso8601_date('{dttm.date().isoformat()}')"
return f"DATE '{dttm.date().isoformat()}'"
if tt in (
utils.TemporalType.TIMESTAMP,
utils.TemporalType.TIMESTAMP_WITH_TIME_ZONE,
):
return f"""from_iso8601_timestamp('{dttm.isoformat(timespec="microseconds")}')""" # pylint: disable=line-too-long,useless-suppression
return f"""TIMESTAMP '{dttm.isoformat(timespec="microseconds")}'"""
return None

@classmethod
Expand Down
65 changes: 0 additions & 65 deletions superset/models/sql_types/base.py

This file was deleted.

6 changes: 2 additions & 4 deletions tests/integration_tests/db_engine_specs/presto_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -517,12 +517,12 @@ def test_convert_dttm(self):

self.assertEqual(
PrestoEngineSpec.convert_dttm("DATE", dttm),
"from_iso8601_date('2019-01-02')",
"DATE '2019-01-02'",
)

self.assertEqual(
PrestoEngineSpec.convert_dttm("TIMESTAMP", dttm),
"from_iso8601_timestamp('2019-01-02T03:04:05.678900')",
"TIMESTAMP '2019-01-02T03:04:05.678900'",
)

def test_query_cost_formatter(self):
Expand Down Expand Up @@ -672,12 +672,10 @@ def test_get_sqla_column_type(self):

column_spec = PrestoEngineSpec.get_column_spec("time")
assert isinstance(column_spec.sqla_type, types.Time)
assert type(column_spec.sqla_type).__name__ == "TemporalWrapperType"
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)

column_spec = PrestoEngineSpec.get_column_spec("timestamp")
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
assert type(column_spec.sqla_type).__name__ == "TemporalWrapperType"
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)

sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/db_engine_specs/trino_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,12 @@ def test_convert_dttm(self):

self.assertEqual(
TrinoEngineSpec.convert_dttm("DATE", dttm),
"from_iso8601_date('2019-01-02')",
"DATE '2019-01-02'",
)

self.assertEqual(
TrinoEngineSpec.convert_dttm("TIMESTAMP", dttm),
"from_iso8601_timestamp('2019-01-02T03:04:05.678900')",
"TIMESTAMP '2019-01-02T03:04:05.678900'",
)

def test_adjust_database_uri(self):
Expand Down
24 changes: 3 additions & 21 deletions tests/integration_tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
from superset.common.db_query_status import QueryStatus
from superset.models.core import Database
from superset.models.slice import Slice
from superset.models.sql_types.base import literal_dttm_type_factory
from superset.utils.database import get_example_database

from .base_tests import SupersetTestCase
Expand Down Expand Up @@ -376,30 +375,22 @@ def test_get_sqla_engine(self, mocked_create_engine):
class TestSqlaTableModel(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_timestamp_expression(self):
col_type = (
"VARCHAR"
if get_example_database().backend == "presto"
else "TemporalWrapperType"
)
tbl = self.get_table(name="birth_names")
ds_col = tbl.get_column("ds")
sqla_literal = ds_col.get_timestamp_expression(None)
self.assertEqual(str(sqla_literal.compile()), "ds")
assert type(sqla_literal.type).__name__ == col_type
assert str(sqla_literal.compile()) == "ds"

sqla_literal = ds_col.get_timestamp_expression("P1D")
assert type(sqla_literal.type).__name__ == col_type
compiled = "{}".format(sqla_literal.compile())
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "DATE(ds)")
assert compiled == "DATE(ds)"

prev_ds_expr = ds_col.expression
ds_col.expression = "DATE_ADD(ds, 1)"
sqla_literal = ds_col.get_timestamp_expression("P1D")
assert type(sqla_literal.type).__name__ == col_type
compiled = "{}".format(sqla_literal.compile())
if tbl.database.backend == "mysql":
self.assertEqual(compiled, "DATE(DATE_ADD(ds, 1))")
assert compiled == "DATE(DATE_ADD(ds, 1))"
ds_col.expression = prev_ds_expr

@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
Expand Down Expand Up @@ -615,12 +606,3 @@ def test_data_for_slices_with_adhoc_column(self):

# clean up and auto commit
metadata_db.session.delete(slc)


def test_literal_dttm_type_factory():
orig_type = DateTime()
new_type = literal_dttm_type_factory(
orig_type, PostgresEngineSpec, "TIMESTAMP", db_extra={}
)
assert type(new_type).__name__ == "TemporalWrapperType"
assert str(new_type) == str(orig_type)
Loading

0 comments on commit 9ca53b8

Please sign in to comment.