Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sqla): Normalize prequery result type #17312

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 37 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)

import dateutil.parser
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlparse
Expand Down Expand Up @@ -1455,6 +1456,39 @@ def _get_series_orderby(
)
return ob

def _normalize_prequery_result_type(
    self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
) -> Union[str, int, float, bool, Text]:
    """
    Coerce a prequery result value into its equivalent Python (or SQL text) type.

    Some databases like Druid will return timestamps as strings, but do not
    perform automatic casting when comparing these strings to a timestamp.
    For such cases the value is rewritten via the engine-specific SQL
    transform.

    :param row: A prequery record
    :param dimension: The dimension name
    :param columns_by_name: The mapping of columns by name
    :return: equivalent primitive python type
    """

    value = row[dimension]

    # Unbox numpy scalars (np.int64, np.float32, ...) into native Python values.
    if isinstance(value, np.generic):
        value = value.item()

    column_ = columns_by_name[dimension]

    # Non-temporal columns (or non-string values) need no SQL transform.
    if not (column_.type and column_.is_temporal and isinstance(value, str)):
        return value

    # Temporal string value: let the engine spec produce the proper SQL cast.
    sql = self.db_engine_spec.convert_dttm(
        column_.type, dateutil.parser.parse(value),
    )
    return text(sql) if sql else value

def _get_top_groups(
self,
df: pd.DataFrame,
Expand All @@ -1466,15 +1500,9 @@ def _get_top_groups(
for _unused, row in df.iterrows():
group = []
for dimension in dimensions:
value = utils.normalize_prequery_result_type(row[dimension])

# Some databases like Druid will return timestamps as strings, but
# do not perform automatic casting when comparing these strings to
# a timestamp. For cases like this we convert the value from a
# string into a timestamp.
if columns_by_name[dimension].is_temporal and isinstance(value, str):
dttm = dateutil.parser.parse(value)
value = text(self.db_engine_spec.convert_dttm("TIMESTAMP", dttm))
value = self._normalize_prequery_result_type(
row, dimension, columns_by_name,
)

group.append(groupby_exprs[dimension] == value)
groups.append(and_(*group))
Expand Down
32 changes: 0 additions & 32 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,35 +1813,3 @@ def escape_sqla_query_binds(sql: str) -> str:
sql = sql.replace(bind, bind.replace(":", "\\:"))
processed_binds.add(bind)
return sql


def normalize_prequery_result_type(
    value: Union[str, int, float, bool, np.generic]
) -> Union[str, int, float, bool]:
    """
    Unbox a value that may be a numpy scalar into its native Python equivalent.

    Non-numpy values pass through unchanged.

    :param value: primitive datatype in either numpy or python format
    :return: equivalent primitive python type
    >>> normalize_prequery_result_type('abc')
    'abc'
    >>> normalize_prequery_result_type(True)
    True
    >>> normalize_prequery_result_type(123)
    123
    >>> normalize_prequery_result_type(np.int16(123))
    123
    >>> normalize_prequery_result_type(np.uint32(123))
    123
    >>> normalize_prequery_result_type(np.int64(123))
    123
    >>> normalize_prequery_result_type(123.456)
    123.456
    >>> normalize_prequery_result_type(np.float32(123.456))
    123.45600128173828
    >>> normalize_prequery_result_type(np.float64(123.456))
    123.456
    """
    # np.generic covers every numpy scalar type; .item() yields the Python value.
    return value.item() if isinstance(value, np.generic) else value
77 changes: 76 additions & 1 deletion tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@
# under the License.
# isort:skip_file
import re
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
from unittest.mock import patch
import pytest

import numpy as np
import pandas as pd
import sqlalchemy as sa
from flask import Flask
from pytest_mock import MockFixture
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause

from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
Expand All @@ -33,6 +40,7 @@
FilterOperator,
GenericDataType,
get_example_database,
TemporalType,
)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand Down Expand Up @@ -484,3 +492,70 @@ def test_values_for_column(self):
)
assert None not in without_null
assert len(without_null) == 2


@pytest.mark.parametrize(
    "row,dimension,result",
    [
        (pd.Series({"foo": "abc"}), "foo", "abc"),
        (pd.Series({"bar": True}), "bar", True),
        (pd.Series({"baz": 123}), "baz", 123),
        (pd.Series({"baz": np.int16(123)}), "baz", 123),
        (pd.Series({"baz": np.uint32(123)}), "baz", 123),
        (pd.Series({"baz": np.int64(123)}), "baz", 123),
        (pd.Series({"qux": 123.456}), "qux", 123.456),
        (pd.Series({"qux": np.float32(123.456)}), "qux", 123.45600128173828),
        (pd.Series({"qux": np.float64(123.456)}), "qux", 123.456),
        (pd.Series({"quux": "2021-01-01"}), "quux", "2021-01-01"),
        (
            pd.Series({"quuz": "2021-01-01T00:00:00"}),
            "quuz",
            text("TIME_PARSE('2021-01-01T00:00:00')"),
        ),
    ],
)
def test__normalize_prequery_result_type(
    app_context: Flask,
    mocker: MockFixture,
    row: pd.Series,
    dimension: str,
    result: Any,
) -> None:
    """
    Verify that ``SqlaTable._normalize_prequery_result_type`` unboxes numpy
    scalars to native Python values and rewrites temporal string values for
    TIMESTAMP columns via the engine's ``convert_dttm`` transform.
    """

    def _convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
        # Mimic an engine spec that only knows how to cast TIMESTAMP columns;
        # any other temporal type yields None (no SQL transform).
        if target_type.upper() == TemporalType.TIMESTAMP:
            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""

        return None

    table = SqlaTable(table_name="foobar", database=get_example_database())
    mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)

    # Each column_name matches its dict key (previously "quux"/"quuz" were
    # swapped, making the fixture internally inconsistent).
    columns_by_name = {
        "foo": TableColumn(
            column_name="foo", is_dttm=False, table=table, type="STRING",
        ),
        "bar": TableColumn(
            column_name="bar", is_dttm=False, table=table, type="BOOLEAN",
        ),
        "baz": TableColumn(
            column_name="baz", is_dttm=False, table=table, type="INTEGER",
        ),
        "qux": TableColumn(
            column_name="qux", is_dttm=False, table=table, type="FLOAT",
        ),
        "quux": TableColumn(
            column_name="quux", is_dttm=True, table=table, type="STRING",
        ),
        "quuz": TableColumn(
            column_name="quuz", is_dttm=True, table=table, type="TIMESTAMP",
        ),
    }

    normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name)

    # Exact-type check on purpose: bool is a subclass of int, so isinstance
    # would not distinguish True from 1.
    assert type(normalized) == type(result)

    if isinstance(normalized, TextClause):
        # TextClause does not implement value equality; compare rendered SQL.
        assert str(normalized) == str(result)
    else:
        assert normalized == result