From 8dd6248a832ac44df0bf09d2770a60ff6b1856ab Mon Sep 17 00:00:00 2001 From: ofekisr Date: Thu, 18 Nov 2021 15:16:43 +0200 Subject: [PATCH] refactor: queryObject - add QueryObjectFactory --- superset/common/query_context.py | 6 +- superset/common/query_object.py | 133 +++++++++++------- .../integration_tests/query_context_tests.py | 1 + 3 files changed, 88 insertions(+), 52 deletions(-) diff --git a/superset/common/query_context.py b/superset/common/query_context.py index 12a27b77ae7b3..6dc28708eaa63 100644 --- a/superset/common/query_context.py +++ b/superset/common/query_context.py @@ -32,7 +32,7 @@ from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.common.db_query_status import QueryStatus from superset.common.query_actions import get_query_results -from superset.common.query_object import QueryObject +from superset.common.query_object import QueryObject, QueryObjectFactory from superset.common.utils import QueryCacheManager from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry @@ -102,8 +102,10 @@ def __init__( ) self.result_type = result_type or ChartDataResultType.FULL self.result_format = result_format or ChartDataResultFormat.JSON + query_object_factory = QueryObjectFactory() self.queries = [ - QueryObject(self.result_type, **query_obj) for query_obj in queries + query_object_factory.create(self.result_type, **query_obj) + for query_obj in queries ] self.force = force self.custom_cache_timeout = custom_cache_timeout diff --git a/superset/common/query_object.py b/superset/common/query_object.py index fd4f6f06525ff..e7480370260d3 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -14,19 +14,18 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=invalid-name +# pylint: disable=invalid-name, no-self-use from __future__ import annotations import logging from datetime import datetime, timedelta -from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING from flask_babel import gettext as _ from pandas import DataFrame from superset import app, db from superset.common.chart_data import ChartDataResultType -from superset.connectors.base.models import BaseDatasource from superset.connectors.connector_registry import ConnectorRegistry from superset.exceptions import QueryObjectValidationError from superset.typing import Column, Metric, OrderBy @@ -47,7 +46,7 @@ from superset.views.utils import get_time_range_endpoints if TYPE_CHECKING: - from superset.common.query_context import QueryContext # pragma: no cover + from superset.connectors.base.models import BaseDatasource config = app.config @@ -111,14 +110,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes time_range: Optional[str] to_dttm: Optional[datetime] - def __init__( # pylint: disable=too-many-arguments,too-many-locals + def __init__( # pylint: disable=too-many-locals self, - parent_result_type: ChartDataResultType, + *, annotation_layers: Optional[List[Dict[str, Any]]] = None, applied_time_extras: Optional[Dict[str, str]] = None, apply_fetch_values_predicate: bool = False, columns: Optional[List[Column]] = None, - datasource: Optional[DatasourceDict] = None, + datasource: Optional[BaseDatasource] = None, extras: Optional[Dict[str, Any]] = None, filters: Optional[List[QueryObjectFilterClause]] = None, granularity: Optional[str] = None, @@ -128,7 +127,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals order_desc: bool = True, orderby: Optional[List[OrderBy]] = None, post_processing: Optional[List[Optional[Dict[str, Any]]]] = None, - row_limit: Optional[int] = None, + row_limit: int, row_offset: Optional[int] = None, series_columns: Optional[List[Column]] = None, series_limit: int = 0, @@ -137,13 +136,12 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals time_shift: Optional[str] = None, **kwargs: Any, ): - self.result_type = kwargs.get("result_type", parent_result_type) self._set_annotation_layers(annotation_layers) self.applied_time_extras = applied_time_extras or {} self.apply_fetch_values_predicate = apply_fetch_values_predicate or False self.columns = columns or [] - self._set_datasource(datasource) - self._set_extras(extras) + self.datasource = datasource + self.extras = extras or {} self.filter = filters or [] self.granularity = granularity self.is_rowcount = is_rowcount @@ -152,14 +150,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals self.order_desc = order_desc self.orderby = orderby or [] self._set_post_processing(post_processing) - self._set_row_limit(row_limit) + self.row_limit = row_limit self.row_offset = row_offset or 0 self._init_series_columns(series_columns, metrics, is_timeseries) self.series_limit = series_limit self.series_limit_metric = series_limit_metric - self.set_dttms(time_range, time_shift) self.time_range = time_range self.time_shift = parse_human_timedelta(time_shift) + self.from_dttm = kwargs.get("from_dttm") + self.to_dttm = kwargs.get("to_dttm") + self.result_type = kwargs.get("result_type") self.time_offsets = kwargs.get("time_offsets", []) self.inner_from_dttm = kwargs.get("inner_from_dttm") self.inner_to_dttm = kwargs.get("inner_to_dttm") @@ -176,20 +176,6 @@ def _set_annotation_layers( if layer["annotationType"] != "FORMULA" ] - def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None: - self.datasource = None - if datasource: - self.datasource = ConnectorRegistry.get_datasource( - str(datasource["type"]), int(datasource["id"]), db.session - ) - - def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None: - self.extras = extras or {} - if config["SIP_15_ENABLED"]: - self.extras["time_range_endpoints"] = get_time_range_endpoints( - form_data=self.extras - ) - def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None: # is_timeseries is True if time column is in either columns or groupby # (both are dimensions) @@ -212,17 +198,8 @@ def is_str_or_adhoc(metric: Metric) -> bool: def _set_post_processing( self, post_processing: Optional[List[Optional[Dict[str, Any]]]] ) -> None: - self.post_processing = [ - post_proc for post_proc in post_processing or [] if post_proc - ] - - def _set_row_limit(self, row_limit: Optional[int]) -> None: - default_row_limit = ( - config["SAMPLES_ROW_LIMIT"] - if self.result_type == ChartDataResultType.SAMPLES - else config["ROW_LIMIT"] - ) - self.row_limit = apply_max_row_limit(row_limit or default_row_limit) + post_processing = post_processing or [] + self.post_processing = [post_proc for post_proc in post_processing if post_proc] def _init_series_columns( self, @@ -237,18 +214,6 @@ def _init_series_columns( else: self.series_columns = [] - def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None: - self.from_dttm, self.to_dttm = get_since_until( - relative_start=self.extras.get( - "relative_start", config["DEFAULT_RELATIVE_START_TIME"] - ), - relative_end=self.extras.get( - "relative_end", config["DEFAULT_RELATIVE_END_TIME"] - ), - time_range=time_range, - time_shift=time_shift, - ) - def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None: # rename deprecated fields for field in DEPRECATED_FIELDS: @@ -439,3 +404,71 @@ def exec_post_processing(self, df: DataFrame) -> DataFrame: options = post_process.get("options", {}) df = getattr(pandas_postprocessing, operation)(df, **options) return df + + +class QueryObjectFactory: # pylint: disable=too-few-public-methods + def create( # pylint: disable=too-many-arguments + self, + parent_result_type: ChartDataResultType, + datasource: Optional[DatasourceDict] = None, + extras: Optional[Dict[str, Any]] = None, + row_limit: Optional[int] = None, + time_range: Optional[str] = None, + time_shift: Optional[str] = None, + **kwargs: Any, + ) -> QueryObject: + datasource_model_instance = None + if datasource: + datasource_model_instance = self._convert_to_model(datasource) + processed_extras = self._process_extras(extras) + result_type = kwargs.setdefault("result_type", parent_result_type) + row_limit = self._process_row_limit(row_limit, result_type) + from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras) + kwargs["from_dttm"] = from_dttm + kwargs["to_dttm"] = to_dttm + return QueryObject( + datasource=datasource_model_instance, + extras=extras, + row_limit=row_limit, + time_range=time_range, + time_shift=time_shift, + **kwargs, + ) + + def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource: + return ConnectorRegistry.get_datasource( + str(datasource["type"]), int(datasource["id"]), db.session + ) + + def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]: + extras = extras or {} + if config["SIP_15_ENABLED"]: + extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras) + return extras + + def _process_row_limit( + self, row_limit: Optional[int], result_type: ChartDataResultType + ) -> int: + default_row_limit = ( + config["SAMPLES_ROW_LIMIT"] + if result_type == ChartDataResultType.SAMPLES + else config["ROW_LIMIT"] + ) + return apply_max_row_limit(row_limit or default_row_limit) + + def _get_dttms( + self, + time_range: Optional[str], + time_shift: Optional[str], + extras: Dict[str, Any], + ) -> Tuple[Optional[datetime], Optional[datetime]]: + return get_since_until( + relative_start=extras.get( + "relative_start", config["DEFAULT_RELATIVE_START_TIME"] + ), + relative_end=extras.get( + "relative_end", config["DEFAULT_RELATIVE_END_TIME"] + ), + time_range=time_range, + time_shift=time_shift, + ) diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py index d7d93696cd4fe..6df246db74e96 100644 --- a/tests/integration_tests/query_context_tests.py +++ b/tests/integration_tests/query_context_tests.py @@ -50,6 +50,7 @@ def get_sql_text(payload: Dict[str, Any]) -> str: class TestQueryContext(SupersetTestCase): + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_schema_deserialization(self): """ Ensure that the deserialized QueryContext contains all required fields.