diff --git a/superset/common/query_context.py b/superset/common/query_context.py index d7eb62c743381..12a27b77ae7b3 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -80,10 +80,10 @@ class QueryContext: datasource: BaseDatasource queries: List[QueryObject] - force: bool - custom_cache_timeout: Optional[int] result_type: ChartDataResultType result_format: ChartDataResultFormat + force: bool + custom_cache_timeout: Optional[int] # TODO: Type datasource and query_object dictionary with TypedDict when it becomes # a vanilla python type https://github.com/python/mypy/issues/5288 @@ -92,19 +92,21 @@ def __init__( self, datasource: DatasourceDict, queries: List[Dict[str, Any]], - force: bool = False, - custom_cache_timeout: Optional[int] = None, result_type: Optional[ChartDataResultType] = None, result_format: Optional[ChartDataResultFormat] = None, + force: bool = False, + custom_cache_timeout: Optional[int] = None, ) -> None: self.datasource = ConnectorRegistry.get_datasource( str(datasource["type"]), int(datasource["id"]), db.session ) - self.force = force - self.custom_cache_timeout = custom_cache_timeout self.result_type = result_type or ChartDataResultType.FULL self.result_format = result_format or ChartDataResultFormat.JSON - self.queries = [QueryObject(self, **query_obj) for query_obj in queries] + self.queries = [ + QueryObject(self.result_type, **query_obj) for query_obj in queries + ] + self.force = force + self.custom_cache_timeout = custom_cache_timeout self.cache_values = { "datasource": datasource, "queries": queries, diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 2894392c8b60c..fd4f6f06525ff 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -15,6 +15,8 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=invalid-name +from __future__ import annotations + import logging from datetime import datetime, timedelta from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING @@ -106,11 +108,12 @@ class QueryObject: # pylint: disable=too-many-instance-attributes series_limit_metric: Optional[Metric] time_offsets: List[str] time_shift: Optional[timedelta] + time_range: Optional[str] to_dttm: Optional[datetime] def __init__( # pylint: disable=too-many-arguments,too-many-locals self, - query_context: "QueryContext", + parent_result_type: ChartDataResultType, annotation_layers: Optional[List[Dict[str, Any]]] = None, applied_time_extras: Optional[Dict[str, str]] = None, apply_fetch_values_predicate: bool = False, @@ -125,7 +128,6 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals order_desc: bool = True, orderby: Optional[List[OrderBy]] = None, post_processing: Optional[List[Optional[Dict[str, Any]]]] = None, - result_type: Optional[ChartDataResultType] = None, row_limit: Optional[int] = None, row_offset: Optional[int] = None, series_columns: Optional[List[Column]] = None, @@ -135,88 +137,117 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals time_shift: Optional[str] = None, **kwargs: Any, ): - columns = columns or [] - extras = extras or {} - annotation_layers = annotation_layers or [] + self.result_type = kwargs.get("result_type", parent_result_type) + self._set_annotation_layers(annotation_layers) + self.applied_time_extras = applied_time_extras or {} + self.apply_fetch_values_predicate = apply_fetch_values_predicate or False + self.columns = columns or [] + self._set_datasource(datasource) + self._set_extras(extras) + self.filter = filters or [] + self.granularity = granularity + self.is_rowcount = is_rowcount + self._set_is_timeseries(is_timeseries) + self._set_metrics(metrics) + self.order_desc = order_desc + self.orderby = orderby or [] + self._set_post_processing(post_processing) + self._set_row_limit(row_limit) + self.row_offset = row_offset or 0 + self._init_series_columns(series_columns, metrics, is_timeseries) + self.series_limit = series_limit + self.series_limit_metric = series_limit_metric + self.set_dttms(time_range, time_shift) + self.time_range = time_range + self.time_shift = parse_human_timedelta(time_shift) self.time_offsets = kwargs.get("time_offsets", []) self.inner_from_dttm = kwargs.get("inner_from_dttm") self.inner_to_dttm = kwargs.get("inner_to_dttm") - if series_columns: - self.series_columns = series_columns - elif is_timeseries and metrics: - self.series_columns = columns - else: - self.series_columns = [] + self._rename_deprecated_fields(kwargs) + self._move_deprecated_extra_fields(kwargs) - self.is_rowcount = is_rowcount - self.datasource = None - if datasource: - self.datasource = ConnectorRegistry.get_datasource( - str(datasource["type"]), int(datasource["id"]), db.session - ) - self.result_type = result_type or query_context.result_type - self.apply_fetch_values_predicate = apply_fetch_values_predicate or False + def _set_annotation_layers( + self, annotation_layers: Optional[List[Dict[str, Any]]] + ) -> None: self.annotation_layers = [ layer - for layer in annotation_layers + for layer in (annotation_layers or []) # formula annotations don't affect the payload, hence can be dropped if layer["annotationType"] != "FORMULA" ] - self.applied_time_extras = applied_time_extras or {} - self.granularity = granularity - self.from_dttm, self.to_dttm = get_since_until( - relative_start=extras.get( - "relative_start", config["DEFAULT_RELATIVE_START_TIME"] - ), - relative_end=extras.get( - "relative_end", config["DEFAULT_RELATIVE_END_TIME"] - ), - time_range=time_range, - time_shift=time_shift, - ) + + def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None: + self.datasource = None + if datasource: + self.datasource = ConnectorRegistry.get_datasource( + str(datasource["type"]), int(datasource["id"]), db.session + ) + + def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None: + self.extras = extras or {} + if config["SIP_15_ENABLED"]: + self.extras["time_range_endpoints"] = get_time_range_endpoints( + form_data=self.extras + ) + + def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None: # is_timeseries is True if time column is in either columns or groupby # (both are dimensions) self.is_timeseries = ( - is_timeseries if is_timeseries is not None else DTTM_ALIAS in columns + is_timeseries if is_timeseries is not None else DTTM_ALIAS in self.columns ) - self.time_range = time_range - self.time_shift = parse_human_timedelta(time_shift) - self.post_processing = [ - post_proc for post_proc in post_processing or [] if post_proc - ] + def _set_metrics(self, metrics: Optional[List[Metric]] = None) -> None: # Support metric reference/definition in the format of # 1. 'metric_name' - name of predefined metric # 2. { label: 'label_name' } - legacy format for a predefined metric # 3. { expressionType: 'SIMPLE' | 'SQL', ... } - adhoc metric + def is_str_or_adhoc(metric: Metric) -> bool: + return isinstance(metric, str) or is_adhoc_metric(metric) + self.metrics = metrics and [ - x if isinstance(x, str) or is_adhoc_metric(x) else x["label"] # type: ignore - for x in metrics + x if is_str_or_adhoc(x) else x["label"] for x in metrics # type: ignore + ] + + def _set_post_processing( + self, post_processing: Optional[List[Optional[Dict[str, Any]]]] + ) -> None: + self.post_processing = [ + post_proc for post_proc in post_processing or [] if post_proc ] + def _set_row_limit(self, row_limit: Optional[int]) -> None: default_row_limit = ( config["SAMPLES_ROW_LIMIT"] if self.result_type == ChartDataResultType.SAMPLES else config["ROW_LIMIT"] ) self.row_limit = apply_max_row_limit(row_limit or default_row_limit) - self.row_offset = row_offset or 0 - self.filter = filters or [] - self.series_limit = series_limit - self.series_limit_metric = series_limit_metric - self.order_desc = order_desc - self.extras = extras - if config["SIP_15_ENABLED"]: - self.extras["time_range_endpoints"] = get_time_range_endpoints( - form_data=self.extras - ) - - self.columns = columns - self.orderby = orderby or [] + def _init_series_columns( + self, + series_columns: Optional[List[Column]], + metrics: Optional[List[Metric]], + is_timeseries: Optional[bool], + ) -> None: + if series_columns: + self.series_columns = series_columns + elif is_timeseries and metrics: + self.series_columns = self.columns + else: + self.series_columns = [] - self._rename_deprecated_fields(kwargs) - self._move_deprecated_extra_fields(kwargs) + def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None: + self.from_dttm, self.to_dttm = get_since_until( + relative_start=self.extras.get( + "relative_start", config["DEFAULT_RELATIVE_START_TIME"] + ), + relative_end=self.extras.get( + "relative_end", config["DEFAULT_RELATIVE_END_TIME"] + ), + time_range=time_range, + time_shift=time_shift, + ) def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None: # rename deprecated fields