diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index c3d440a624f64..b653693e7d63a 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -36,6 +36,7 @@
 )
 
 import dateutil.parser
+import numpy as np
 import pandas as pd
 import sqlalchemy as sa
 import sqlparse
@@ -1455,6 +1456,39 @@ def _get_series_orderby(
         )
         return ob
 
+    def _normalize_prequery_result_type(
+        self, row: pd.Series, dimension: str, columns_by_name: Dict[str, TableColumn],
+    ) -> Union[str, int, float, bool, Text]:
+        """
+        Convert a prequery result type to its equivalent Python type.
+
+        Some databases like Druid will return timestamps as strings, but do not perform
+        automatic casting when comparing these strings to a timestamp. For cases like
+        this we convert the value via the appropriate SQL transform.
+
+        :param row: A prequery record
+        :param dimension: The dimension name
+        :param columns_by_name: The mapping of columns by name
+        :return: equivalent primitive python type
+        """
+
+        value = row[dimension]
+
+        if isinstance(value, np.generic):
+            value = value.item()
+
+        column_ = columns_by_name[dimension]
+
+        if column_.type and column_.is_temporal and isinstance(value, str):
+            sql = self.db_engine_spec.convert_dttm(
+                column_.type, dateutil.parser.parse(value),
+            )
+
+            if sql:
+                value = text(sql)
+
+        return value
+
     def _get_top_groups(
         self,
         df: pd.DataFrame,
@@ -1466,15 +1500,9 @@ def _get_top_groups(
         for _unused, row in df.iterrows():
             group = []
             for dimension in dimensions:
-                value = utils.normalize_prequery_result_type(row[dimension])
-
-                # Some databases like Druid will return timestamps as strings, but
-                # do not perform automatic casting when comparing these strings to
-                # a timestamp. For cases like this we convert the value from a
-                # string into a timestamp.
-                if columns_by_name[dimension].is_temporal and isinstance(value, str):
-                    dttm = dateutil.parser.parse(value)
-                    value = text(self.db_engine_spec.convert_dttm("TIMESTAMP", dttm))
+                value = self._normalize_prequery_result_type(
+                    row, dimension, columns_by_name,
+                )
 
                 group.append(groupby_exprs[dimension] == value)
             groups.append(and_(*group))
diff --git a/superset/utils/core.py b/superset/utils/core.py
index bb4dfdaf2e799..83191b2303d0b 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1813,35 +1813,3 @@ def escape_sqla_query_binds(sql: str) -> str:
             sql = sql.replace(bind, bind.replace(":", "\\:"))
             processed_binds.add(bind)
     return sql
-
-
-def normalize_prequery_result_type(
-    value: Union[str, int, float, bool, np.generic]
-) -> Union[str, int, float, bool]:
-    """
-    Convert a value that is potentially a numpy type into its equivalent Python type.
-
-    :param value: primitive datatype in either numpy or python format
-    :return: equivalent primitive python type
-    >>> normalize_prequery_result_type('abc')
-    'abc'
-    >>> normalize_prequery_result_type(True)
-    True
-    >>> normalize_prequery_result_type(123)
-    123
-    >>> normalize_prequery_result_type(np.int16(123))
-    123
-    >>> normalize_prequery_result_type(np.uint32(123))
-    123
-    >>> normalize_prequery_result_type(np.int64(123))
-    123
-    >>> normalize_prequery_result_type(123.456)
-    123.456
-    >>> normalize_prequery_result_type(np.float32(123.456))
-    123.45600128173828
-    >>> normalize_prequery_result_type(np.float64(123.456))
-    123.456
-    """
-    if isinstance(value, np.generic):
-        return value.item()
-    return value
diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py
index e92ba165687ce..03f2d88a0790e 100644
--- a/tests/integration_tests/sqla_models_tests.py
+++ b/tests/integration_tests/sqla_models_tests.py
@@ -16,11 +16,18 @@
 # under the License.
 # isort:skip_file
 import re
-from typing import Any, Dict, NamedTuple, List, Pattern, Tuple, Union
+from datetime import datetime
+from typing import Any, Dict, List, NamedTuple, Optional, Pattern, Tuple, Union
 from unittest.mock import patch
 
 import pytest
+import numpy as np
+import pandas as pd
 import sqlalchemy as sa
+from flask import Flask
+from pytest_mock import MockFixture
+from sqlalchemy.sql import text
+from sqlalchemy.sql.elements import TextClause
 
 from superset import db
 from superset.connectors.sqla.models import SqlaTable, TableColumn
@@ -33,6 +40,7 @@
     FilterOperator,
     GenericDataType,
     get_example_database,
+    TemporalType,
 )
 from tests.integration_tests.fixtures.birth_names_dashboard import (
     load_birth_names_dashboard_with_slices,
@@ -484,3 +492,70 @@ def test_values_for_column(self):
         )
         assert None not in without_null
         assert len(without_null) == 2
+
+
+@pytest.mark.parametrize(
+    "row,dimension,result",
+    [
+        (pd.Series({"foo": "abc"}), "foo", "abc"),
+        (pd.Series({"bar": True}), "bar", True),
+        (pd.Series({"baz": 123}), "baz", 123),
+        (pd.Series({"baz": np.int16(123)}), "baz", 123),
+        (pd.Series({"baz": np.uint32(123)}), "baz", 123),
+        (pd.Series({"baz": np.int64(123)}), "baz", 123),
+        (pd.Series({"qux": 123.456}), "qux", 123.456),
+        (pd.Series({"qux": np.float32(123.456)}), "qux", 123.45600128173828),
+        (pd.Series({"qux": np.float64(123.456)}), "qux", 123.456),
+        (pd.Series({"quux": "2021-01-01"}), "quux", "2021-01-01"),
+        (
+            pd.Series({"quuz": "2021-01-01T00:00:00"}),
+            "quuz",
+            text("TIME_PARSE('2021-01-01T00:00:00')"),
+        ),
+    ],
+)
+def test__normalize_prequery_result_type(
+    app_context: Flask,
+    mocker: MockFixture,
+    row: pd.Series,
+    dimension: str,
+    result: Any,
+) -> None:
+    def _convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
+        if target_type.upper() == TemporalType.TIMESTAMP:
+            return f"""TIME_PARSE('{dttm.isoformat(timespec="seconds")}')"""
+
+        return None
+
+    table = SqlaTable(table_name="foobar", database=get_example_database())
+    mocker.patch.object(table.db_engine_spec, "convert_dttm", new=_convert_dttm)
+
+    columns_by_name = {
+        "foo": TableColumn(
+            column_name="foo", is_dttm=False, table=table, type="STRING",
+        ),
+        "bar": TableColumn(
+            column_name="bar", is_dttm=False, table=table, type="BOOLEAN",
+        ),
+        "baz": TableColumn(
+            column_name="baz", is_dttm=False, table=table, type="INTEGER",
+        ),
+        "qux": TableColumn(
+            column_name="qux", is_dttm=False, table=table, type="FLOAT",
+        ),
+        "quux": TableColumn(
+            column_name="quux", is_dttm=True, table=table, type="STRING",
+        ),
+        "quuz": TableColumn(
+            column_name="quuz", is_dttm=True, table=table, type="TIMESTAMP",
+        ),
+    }
+
+    normalized = table._normalize_prequery_result_type(row, dimension, columns_by_name,)
+
+    assert type(normalized) == type(result)
+
+    if isinstance(normalized, TextClause):
+        assert str(normalized) == str(result)
+    else:
+        assert normalized == result
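For reviewers who want to see the conversion in isolation, below is a minimal standalone sketch (not part of the diff) of what the new `SqlaTable._normalize_prequery_result_type` helper does: numpy scalars are unwrapped into plain Python values, and temporal string values are wrapped in an engine-specific SQL transform so prequery filters compare timestamps rather than raw strings. `fake_convert_dttm` and `normalize` are hypothetical stand-ins assuming a Druid-style `TIME_PARSE` transform that mirrors the mock in the test above; in the real code the SQL comes from `db_engine_spec.convert_dttm` and the temporal check uses `TableColumn.is_temporal`.

```python
from datetime import datetime
from typing import Optional, Union

import dateutil.parser
import numpy as np
from sqlalchemy import text
from sqlalchemy.sql.elements import TextClause


def fake_convert_dttm(target_type: str, dttm: datetime) -> Optional[str]:
    # Stand-in for db_engine_spec.convert_dttm, mirroring the Druid-style mock in the test.
    if target_type.upper() == "TIMESTAMP":
        return f"TIME_PARSE('{dttm.isoformat(timespec='seconds')}')"
    return None


def normalize(
    value: Union[str, int, float, bool, np.generic],
    column_type: str,
    is_temporal: bool,
) -> Union[str, int, float, bool, TextClause]:
    # Unwrap numpy scalars (e.g. np.int64(123)) into plain Python values.
    if isinstance(value, np.generic):
        value = value.item()
    # Wrap temporal strings in the engine-specific SQL transform so the
    # prequery filter compares timestamps, not raw strings.
    if column_type and is_temporal and isinstance(value, str):
        sql = fake_convert_dttm(column_type, dateutil.parser.parse(value))
        if sql:
            value = text(sql)
    return value


print(normalize(np.int64(123), "INTEGER", False))           # 123
print(normalize("2021-01-01T00:00:00", "TIMESTAMP", True))  # TIME_PARSE('2021-01-01T00:00:00')
print(normalize("2021-01-01", "STRING", True))               # 2021-01-01 (no transform for STRING)
```

Running the sketch prints `123`, `TIME_PARSE('2021-01-01T00:00:00')`, and `2021-01-01`, matching the expectations for the corresponding rows in the parametrized test.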