Skip to content

Commit

Permalink
refactor: queryObject - add QueryObjectFactory
Browse files Browse the repository at this point in the history
  • Loading branch information
ofekisr committed Nov 18, 2021
1 parent b914e2d commit 8dd6248
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 52 deletions.
6 changes: 4 additions & 2 deletions superset/common/query_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_actions import get_query_results
from superset.common.query_object import QueryObject
from superset.common.query_object import QueryObject, QueryObjectFactory
from superset.common.utils import QueryCacheManager
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
Expand Down Expand Up @@ -102,8 +102,10 @@ def __init__(
)
self.result_type = result_type or ChartDataResultType.FULL
self.result_format = result_format or ChartDataResultFormat.JSON
query_object_factory = QueryObjectFactory()
self.queries = [
QueryObject(self.result_type, **query_obj) for query_obj in queries
query_object_factory.create(self.result_type, **query_obj)
for query_obj in queries
]
self.force = force
self.custom_cache_timeout = custom_cache_timeout
Expand Down
133 changes: 83 additions & 50 deletions superset/common/query_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name
# pylint: disable=invalid-name, no-self-use
from __future__ import annotations

import logging
from datetime import datetime, timedelta
from typing import Any, Dict, List, NamedTuple, Optional, TYPE_CHECKING
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, TYPE_CHECKING

from flask_babel import gettext as _
from pandas import DataFrame

from superset import app, db
from superset.common.chart_data import ChartDataResultType
from superset.connectors.base.models import BaseDatasource
from superset.connectors.connector_registry import ConnectorRegistry
from superset.exceptions import QueryObjectValidationError
from superset.typing import Column, Metric, OrderBy
Expand All @@ -47,7 +46,7 @@
from superset.views.utils import get_time_range_endpoints

if TYPE_CHECKING:
from superset.common.query_context import QueryContext # pragma: no cover
from superset.connectors.base.models import BaseDatasource


config = app.config
Expand Down Expand Up @@ -111,14 +110,14 @@ class QueryObject: # pylint: disable=too-many-instance-attributes
time_range: Optional[str]
to_dttm: Optional[datetime]

def __init__( # pylint: disable=too-many-arguments,too-many-locals
def __init__( # pylint: disable=too-many-locals
self,
parent_result_type: ChartDataResultType,
*,
annotation_layers: Optional[List[Dict[str, Any]]] = None,
applied_time_extras: Optional[Dict[str, str]] = None,
apply_fetch_values_predicate: bool = False,
columns: Optional[List[Column]] = None,
datasource: Optional[DatasourceDict] = None,
datasource: Optional[BaseDatasource] = None,
extras: Optional[Dict[str, Any]] = None,
filters: Optional[List[QueryObjectFilterClause]] = None,
granularity: Optional[str] = None,
Expand All @@ -128,7 +127,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
order_desc: bool = True,
orderby: Optional[List[OrderBy]] = None,
post_processing: Optional[List[Optional[Dict[str, Any]]]] = None,
row_limit: Optional[int] = None,
row_limit: int,
row_offset: Optional[int] = None,
series_columns: Optional[List[Column]] = None,
series_limit: int = 0,
Expand All @@ -137,13 +136,12 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
time_shift: Optional[str] = None,
**kwargs: Any,
):
self.result_type = kwargs.get("result_type", parent_result_type)
self._set_annotation_layers(annotation_layers)
self.applied_time_extras = applied_time_extras or {}
self.apply_fetch_values_predicate = apply_fetch_values_predicate or False
self.columns = columns or []
self._set_datasource(datasource)
self._set_extras(extras)
self.datasource = datasource
self.extras = extras or {}
self.filter = filters or []
self.granularity = granularity
self.is_rowcount = is_rowcount
Expand All @@ -152,14 +150,16 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals
self.order_desc = order_desc
self.orderby = orderby or []
self._set_post_processing(post_processing)
self._set_row_limit(row_limit)
self.row_limit = row_limit
self.row_offset = row_offset or 0
self._init_series_columns(series_columns, metrics, is_timeseries)
self.series_limit = series_limit
self.series_limit_metric = series_limit_metric
self.set_dttms(time_range, time_shift)
self.time_range = time_range
self.time_shift = parse_human_timedelta(time_shift)
self.from_dttm = kwargs.get("from_dttm")
self.to_dttm = kwargs.get("to_dttm")
self.result_type = kwargs.get("result_type")
self.time_offsets = kwargs.get("time_offsets", [])
self.inner_from_dttm = kwargs.get("inner_from_dttm")
self.inner_to_dttm = kwargs.get("inner_to_dttm")
Expand All @@ -176,20 +176,6 @@ def _set_annotation_layers(
if layer["annotationType"] != "FORMULA"
]

def _set_datasource(self, datasource: Optional[DatasourceDict]) -> None:
self.datasource = None
if datasource:
self.datasource = ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)

def _set_extras(self, extras: Optional[Dict[str, Any]]) -> None:
self.extras = extras or {}
if config["SIP_15_ENABLED"]:
self.extras["time_range_endpoints"] = get_time_range_endpoints(
form_data=self.extras
)

def _set_is_timeseries(self, is_timeseries: Optional[bool]) -> None:
# is_timeseries is True if time column is in either columns or groupby
# (both are dimensions)
Expand All @@ -212,17 +198,8 @@ def is_str_or_adhoc(metric: Metric) -> bool:
def _set_post_processing(
self, post_processing: Optional[List[Optional[Dict[str, Any]]]]
) -> None:
self.post_processing = [
post_proc for post_proc in post_processing or [] if post_proc
]

def _set_row_limit(self, row_limit: Optional[int]) -> None:
default_row_limit = (
config["SAMPLES_ROW_LIMIT"]
if self.result_type == ChartDataResultType.SAMPLES
else config["ROW_LIMIT"]
)
self.row_limit = apply_max_row_limit(row_limit or default_row_limit)
post_processing = post_processing or []
self.post_processing = [post_proc for post_proc in post_processing if post_proc]

def _init_series_columns(
self,
Expand All @@ -237,18 +214,6 @@ def _init_series_columns(
else:
self.series_columns = []

def set_dttms(self, time_range: Optional[str], time_shift: Optional[str]) -> None:
self.from_dttm, self.to_dttm = get_since_until(
relative_start=self.extras.get(
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
),
relative_end=self.extras.get(
"relative_end", config["DEFAULT_RELATIVE_END_TIME"]
),
time_range=time_range,
time_shift=time_shift,
)

def _rename_deprecated_fields(self, kwargs: Dict[str, Any]) -> None:
# rename deprecated fields
for field in DEPRECATED_FIELDS:
Expand Down Expand Up @@ -439,3 +404,71 @@ def exec_post_processing(self, df: DataFrame) -> DataFrame:
options = post_process.get("options", {})
df = getattr(pandas_postprocessing, operation)(df, **options)
return df


class QueryObjectFactory: # pylint: disable=too-few-public-methods
def create( # pylint: disable=too-many-arguments
self,
parent_result_type: ChartDataResultType,
datasource: Optional[DatasourceDict] = None,
extras: Optional[Dict[str, Any]] = None,
row_limit: Optional[int] = None,
time_range: Optional[str] = None,
time_shift: Optional[str] = None,
**kwargs: Any,
) -> QueryObject:
datasource_model_instance = None
if datasource:
datasource_model_instance = self._convert_to_model(datasource)
processed_extras = self._process_extras(extras)
result_type = kwargs.setdefault("result_type", parent_result_type)
row_limit = self._process_row_limit(row_limit, result_type)
from_dttm, to_dttm = self._get_dttms(time_range, time_shift, processed_extras)
kwargs["from_dttm"] = from_dttm
kwargs["to_dttm"] = to_dttm
return QueryObject(
datasource=datasource_model_instance,
extras=extras,
row_limit=row_limit,
time_range=time_range,
time_shift=time_shift,
**kwargs,
)

def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return ConnectorRegistry.get_datasource(
str(datasource["type"]), int(datasource["id"]), db.session
)

def _process_extras(self, extras: Optional[Dict[str, Any]]) -> Dict[str, Any]:
extras = extras or {}
if config["SIP_15_ENABLED"]:
extras["time_range_endpoints"] = get_time_range_endpoints(form_data=extras)
return extras

def _process_row_limit(
self, row_limit: Optional[int], result_type: ChartDataResultType
) -> int:
default_row_limit = (
config["SAMPLES_ROW_LIMIT"]
if result_type == ChartDataResultType.SAMPLES
else config["ROW_LIMIT"]
)
return apply_max_row_limit(row_limit or default_row_limit)

def _get_dttms(
self,
time_range: Optional[str],
time_shift: Optional[str],
extras: Dict[str, Any],
) -> Tuple[Optional[datetime], Optional[datetime]]:
return get_since_until(
relative_start=extras.get(
"relative_start", config["DEFAULT_RELATIVE_START_TIME"]
),
relative_end=extras.get(
"relative_end", config["DEFAULT_RELATIVE_END_TIME"]
),
time_range=time_range,
time_shift=time_shift,
)
1 change: 1 addition & 0 deletions tests/integration_tests/query_context_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def get_sql_text(payload: Dict[str, Any]) -> str:


class TestQueryContext(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_schema_deserialization(self):
"""
Ensure that the deserialized QueryContext contains all required fields.
Expand Down

0 comments on commit 8dd6248

Please sign in to comment.