Skip to content

Commit

Permalink
fix: Normalize prequery result type (#17312)
Browse files Browse the repository at this point in the history
Co-authored-by: John Bodley <[email protected]>
  • Loading branch information
2 people authored and AAfghahi committed Jan 10, 2022
1 parent be724cd commit 594ff5c
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 42 deletions.
46 changes: 37 additions & 9 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)

import dateutil.parser
import numpy as np
import pandas as pd
import sqlalchemy as sa
import sqlparse
Expand Down Expand Up @@ -1455,6 +1456,39 @@ def _get_series_orderby(
)
return ob

def _normalize_prequery_result_type(
self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
) -> Union[str, int, float, bool, Text]:
"""
Convert a prequery result type to its equivalent Python type.
Some databases like Druid will return timestamps as strings, but do not perform
automatic casting when comparing these strings to a timestamp. For cases like
this we convert the value via the appropriate SQL transform.
:param row: A prequery record
:param dimension: The dimension name
:param columns_by_name: The mapping of columns by name
:return: equivalent primitive python type
"""

value = row[dimension]

if isinstance(value, np.generic):
value = value.item()

column_ = columns_by_name[dimension]

if column_.type and column_.is_temporal and isinstance(value, str):
sql = self.db_engine_spec.convert_dttm(
column_.type, dateutil.parser.parse(value),
)

if sql:
value = text(sql)

return value

def _get_top_groups(
self,
df: pd.DataFrame,
Expand All @@ -1466,15 +1500,9 @@ def _get_top_groups(
for _unused, row in df.iterrows():
group = []
for dimension in dimensions:
value = utils.normalize_prequery_result_type(row[dimension])

# Some databases like Druid will return timestamps as strings, but
# do not perform automatic casting when comparing these strings to
# a timestamp. For cases like this we convert the value from a
# string into a timestamp.
if columns_by_name[dimension].is_temporal and isinstance(value, str):
dttm = dateutil.parser.parse(value)
value = text(self.db_engine_spec.convert_dttm("TIMESTAMP", dttm))
value = self._normalize_prequery_result_type(
row, dimension, columns_by_name,
)

group.append(groupby_exprs[dimension] == value)
groups.append(and_(*group))
Expand Down
32 changes: 0 additions & 32 deletions superset/utils/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1813,35 +1813,3 @@ def escape_sqla_query_binds(sql: str) -> str:
sql = sql.replace(bind, bind.replace(":", "\\:"))
processed_binds.add(bind)
return sql


def normalize_prequery_result_type(
value: Union[str, int, float, bool, np.generic]
) -> Union[str, int, float, bool]:
"""
Convert a value that is potentially a numpy type into its equivalent Python type.
:param value: primitive datatype in either numpy or python format
:return: equivalent primitive python type
>>> normalize_prequery_result_type('abc')
'abc'
>>> normalize_prequery_result_type(True)
True
>>> normalize_prequery_result_type(123)
123
>>> normalize_prequery_result_type(np.int16(123))
123
>>> normalize_prequery_result_type(np.uint32(123))
123
>>> normalize_prequery_result_type(np.int64(123))
123
>>> normalize_prequery_result_type(123.456)
123.456
>>> normalize_prequery_result_type(np.float32(123.456))
123.45600128173828
>>> normalize_prequery_result_type(np.float64(123.456))
123.456
"""
if isinstance(value, np.generic):
return value.item()
return value
77 changes: 76 additions & 1 deletion tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,18 @@
# under the License.
# isort:skip_file
import re
from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
from unittest.mock import patch
import pytest

import numpy as np
import pandas as pd
import sqlalchemy as sa
from flask import Flask
from pytest_mock import MockFixture
from sqlalchemy.sql import text
from sqlalchemy.sql.elements import TextClause

from superset import db
from superset.connectors.sqla.models import SqlaTable, TableColumn
Expand All @@ -33,6 +40,7 @@
FilterOperator,
GenericDataType,
get_example_database,
TemporalType,
)
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
Expand Down Expand Up @@ -484,3 +492,70 @@ def test_values_for_column(self):
)
assert None not in without_null
assert len(without_null) == 2


@pytest.mark.parametrize(
"row,dimension,result",
[
(pd.Series({"foo": "abc"}), "foo", "abc"),
(pd.Series({"bar": True}), "bar", True),
(pd.Series({"baz": 123}), "baz", 123),
(pd.Series({"baz": np.int16(123)}), "baz", 123),
(pd.Series({"baz": np.uint32(123)}), "baz", 123),
(pd.Series({"baz": np.int64(123)}), "baz", 123),
(pd.Series({"qux": 123.456}), "qux", 123.456),
(pd.Series({"qux": np.float32(123.456)}), "qux", 123.45600128173828),
(pd.Series({"qux": np.float64(123.456)}), "qux", 123.456),
(pd.Series({"quux": "2021-01-01"}), "quux", "2021-01-01"),
(
pd.Series({"quuz": "2021-01-01T00:00:00"}),
"quuz",
text("TIME_PARSE('2021-01-01T00:00:00')"),
),
],
)
def test__normalize_prequery_result_type(
app_context: Flask,
mocker: MockFixture,
row: pd.Series,
dimension: str,
result: Any,
) -> None:
def _convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
if target_type.upper() == TemporalType.TIMESTAMP:
return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""

return None

table = SqlaTable(table_name="foobar", database=get_example_database())
mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)

columns_by_name = {
"foo": TableColumn(
column_name="foo", is_dttm=False, table=table, type="STRING",
),
"bar": TableColumn(
column_name="bar", is_dttm=False, table=table, type="BOOLEAN",
),
"baz": TableColumn(
column_name="baz", is_dttm=False, table=table, type="INTEGER",
),
"qux": TableColumn(
column_name="qux", is_dttm=False, table=table, type="FLOAT",
),
"quux": TableColumn(
column_name="quuz", is_dttm=True, table=table, type="STRING",
),
"quuz": TableColumn(
column_name="quux", is_dttm=True, table=table, type="TIMESTAMP",
),
}

normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name,)

assert type(normalized) == type(result)

if isinstance(normalized, TextClause):
assert str(normalized) == str(result)
else:
assert normalized == result

0 comments on commit 594ff5c

Please sign in to comment.