diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index cf7d95acd08f2..ffa04abae4c46 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -40,6 +40,7 @@ from superset.charts.post_processing import apply_post_process from superset.charts.schemas import ChartDataQueryContextSchema from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType +from superset.connectors.base.models import BaseDatasource from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger from superset.utils.async_query_manager import AsyncQueryTokenException @@ -158,7 +159,9 @@ def get_data(self, pk: int) -> Response: except (TypeError, json.decoder.JSONDecodeError): form_data = {} - return self._get_data_response(command, form_data=form_data) + return self._get_data_response( + command=command, form_data=form_data, datasource=query_context.datasource + ) @expose("/data", methods=["POST"]) @protect() @@ -327,7 +330,10 @@ def _run_async( return self.response(202, **result) def _send_chart_response( - self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, + self, + result: Dict[Any, Any], + form_data: Optional[Dict[str, Any]] = None, + datasource: Optional[BaseDatasource] = None, ) -> Response: result_type = result["query_context"].result_type result_format = result["query_context"].result_format @@ -336,7 +342,7 @@ def _send_chart_response( # This is needed for sending reports based on text charts that do the # post-processing of data, eg, the pivot table. if result_type == ChartDataResultType.POST_PROCESSED: - result = apply_post_process(result, form_data) + result = apply_post_process(result, form_data, datasource) if result_format == ChartDataResultFormat.CSV: # Verify user has permission to export CSV file @@ -364,6 +370,7 @@ def _get_data_response( command: ChartDataCommand, force_cached: bool = False, form_data: Optional[Dict[str, Any]] = None, + datasource: Optional[BaseDatasource] = None, ) -> Response: try: result = command.run(force_cached=force_cached) @@ -372,7 +379,7 @@ def _get_data_response( except ChartDataQueryFailedError as exc: return self.response_400(message=exc.message) - return self._send_chart_response(result, form_data) + return self._send_chart_response(result, form_data, datasource) # pylint: disable=invalid-name, no-self-use def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]: diff --git a/superset/charts/post_processing.py b/superset/charts/post_processing.py index 7a7fdfe95dd5f..35d2aec9db6ef 100644 --- a/superset/charts/post_processing.py +++ b/superset/charts/post_processing.py @@ -27,13 +27,16 @@ """ from io import StringIO -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING import pandas as pd from superset.common.chart_data import ChartDataResultFormat from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name +if TYPE_CHECKING: + from superset.connectors.base.models import BaseDatasource + def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]: """ @@ -284,7 +287,9 @@ def table(df: pd.DataFrame, form_data: Dict[str, Any]) -> pd.DataFrame: def apply_post_process( - result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None, + result: Dict[Any, Any], + form_data: Optional[Dict[str, Any]] = None, + datasource: Optional["BaseDatasource"] = None, ) -> Dict[Any, Any]: form_data = form_data or {} @@ -306,7 +311,7 @@ def apply_post_process( query["colnames"] = list(processed_df.columns) query["indexnames"] = list(processed_df.index) - query["coltypes"] = extract_dataframe_dtypes(processed_df) + query["coltypes"] = extract_dataframe_dtypes(processed_df, datasource) query["rowcount"] = len(processed_df.index) # Flatten hierarchical columns/index since they are represented as diff --git a/superset/common/query_actions.py b/superset/common/query_actions.py index 9517eb7be3b1d..76ff5c767072f 100644 --- a/superset/common/query_actions.py +++ b/superset/common/query_actions.py @@ -104,7 +104,7 @@ def _get_full( if status != QueryStatus.FAILED: payload["colnames"] = list(df.columns) payload["indexnames"] = list(df.index) - payload["coltypes"] = extract_dataframe_dtypes(df) + payload["coltypes"] = extract_dataframe_dtypes(df, datasource) payload["data"] = query_context.get_data(df) payload["result_format"] = query_context.result_format del payload["df"] diff --git a/superset/utils/core.py b/superset/utils/core.py index 5b91132de694d..f864059db4175 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1597,7 +1597,9 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: return [col for col in map(get_column_name_from_metric, metrics) if col] -def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]: +def extract_dataframe_dtypes( + df: pd.DataFrame, datasource: Optional["BaseDatasource"] = None, +) -> List[GenericDataType]: """Serialize pandas/numpy dtypes to generic types""" # omitting string types as those will be the default type @@ -1612,11 +1614,21 @@ def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]: "date": GenericDataType.TEMPORAL, } + columns_by_name = ( + {column.column_name: column for column in datasource.columns} + if datasource + else {} + ) generic_types: List[GenericDataType] = [] for column in df.columns: + column_object = columns_by_name.get(column) series = df[column] inferred_type = infer_dtype(series) - generic_type = inferred_type_map.get(inferred_type, GenericDataType.STRING) + generic_type = ( + GenericDataType.TEMPORAL + if column_object and column_object.is_dttm + else inferred_type_map.get(inferred_type, GenericDataType.STRING) + ) generic_types.append(generic_type) return generic_types diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index 5ac41beca5e4d..3fcbe959e023a 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -1121,7 +1121,9 @@ def test_get_form_data_token(self): generated_token = get_form_data_token({}) assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_extract_dataframe_dtypes(self): + slc = self.get_slice("Girls", db.session) cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = ( ("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]), ( @@ -1147,10 +1149,13 @@ def test_extract_dataframe_dtypes(self): ("float_null", GenericDataType.NUMERIC, [None, 0.5]), ("bool_null", GenericDataType.BOOLEAN, [None, False]), ("obj_null", GenericDataType.STRING, [None, {"a": 1}]), + # Non-timestamp columns should be identified as temporal if + # `is_dttm` is set to `True` in the underlying datasource + ("ds", GenericDataType.TEMPORAL, [None, {"ds": "2017-01-01"}]), ) df = pd.DataFrame(data={col[0]: col[2] for col in cols}) - assert extract_dataframe_dtypes(df) == [col[1] for col in cols] + assert extract_dataframe_dtypes(df, slc.datasource) == [col[1] for col in cols] def test_normalize_dttm_col(self): def normalize_col(