Skip to content

Commit

Permalink
refactor: sql_json view endpoint: use execution context instead of query (apache#16677)
Browse files Browse the repository at this point in the history

* refactor sql_json view endpoint: use execution context instead of query

* fix failed tests

* fix failed tests

* refactor renaming enum options
  • Loading branch information
ofekisr authored Sep 13, 2021
1 parent 30f4351 commit 521b81a
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 79 deletions.
11 changes: 11 additions & 0 deletions superset/utils/sqllab_execution_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ class SqlJsonExecutionContext: # pylint: disable=too-many-instance-attributes
expand_data: bool
create_table_as_select: Optional[CreateTableAsSelect]
database: Optional[Database]
query: Query
_sql_result: Optional[SqlResults]

def __init__(self, query_params: Dict[str, Any]):
self.create_table_as_select = None
Expand All @@ -64,6 +66,9 @@ def __init__(self, query_params: Dict[str, Any]):
self.user_id = self._get_user_id()
self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])

def set_query(self, query: Query) -> None:
    """Attach the ORM ``Query`` object this execution context operates on."""
    self.query = query

def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
self.database_id = cast(int, query_params.get("database_id"))
self.schema = cast(str, query_params.get("schema"))
Expand Down Expand Up @@ -134,6 +139,12 @@ def _validate_db(self, database: Database) -> None:
# TODO validate db.id is equal to self.database_id
pass

def get_execution_result(self) -> Optional[SqlResults]:
    """Return the SQL results previously recorded via set_execution_result."""
    return self._sql_result

def set_execution_result(self, sql_result: Optional[SqlResults]) -> None:
    """Record the SQL results produced by executing this context's query."""
    self._sql_result = sql_result

def create_query(self) -> Query:
# pylint: disable=line-too-long
start_time = now_as_float()
Expand Down
178 changes: 99 additions & 79 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@
import re
from contextlib import closing
from datetime import datetime, timedelta
from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
from enum import Enum
from typing import Any, Callable, cast, Dict, List, Optional, Union
from urllib import parse

import backoff
Expand Down Expand Up @@ -176,7 +177,7 @@
"your query again."
)

SqlResults = Optional[Dict[str, Any]]
SqlResults = Dict[str, Any]


class Superset(BaseSupersetView): # pylint: disable=too-many-public-methods
Expand Down Expand Up @@ -2423,21 +2424,25 @@ def sql_json(self) -> FlaskResponse:
"user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
}
execution_context = SqlJsonExecutionContext(request.json)
return json_success(*self.sql_json_exec(execution_context, log_params))
status: SqlJsonExecutionStatus = self.sql_json_exec(
execution_context, log_params
)
return self._create_response_from_execution_context(execution_context, status)

def sql_json_exec( # pylint: disable=too-many-statements,useless-suppression
self,
execution_context: SqlJsonExecutionContext,
log_params: Optional[Dict[str, Any]] = None,
) -> Tuple[str, int]:
) -> SqlJsonExecutionStatus:
"""Runs arbitrary sql and returns data as json"""

session = db.session()

query = self._get_existing_query(execution_context, session)

if self.is_query_handled(query):
return self._convert_query_to_payload(cast(Query, query)), 200
execution_context.set_query(query) # type: ignore
return SqlJsonExecutionStatus.QUERY_ALREADY_CREATED

return self._run_sql_json_exec_from_scratch(
execution_context, session, log_params
Expand Down Expand Up @@ -2466,34 +2471,32 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
QueryStatus.TIMED_OUT,
]

@staticmethod
def _convert_query_to_payload(query: Query) -> str:
    """Serialize the query's metadata as the JSON payload sent to the client."""
    body = {"query": query.to_dict()}
    # json_int_dttm_ser renders datetimes as epoch ints; ignore_nan keeps
    # NaN/Infinity out of the emitted JSON (simplejson extension).
    return json.dumps(body, default=utils.json_int_dttm_ser, ignore_nan=True)

def _run_sql_json_exec_from_scratch(
self,
execution_context: SqlJsonExecutionContext,
session: Session,
log_params: Optional[Dict[str, Any]] = None,
) -> Tuple[str, int]:
) -> SqlJsonExecutionStatus:
execution_context.set_database(
self._get_the_query_db(execution_context, session)
)
query = execution_context.create_query()
self._save_new_query(query, session)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query, session)
rendered_query = self._render_query(query, execution_context, session)
try:
self._save_new_query(query, session)
logger.info("Triggering query_id: %i", query.id)
self._validate_access(query, session)
execution_context.set_query(query)
rendered_query = self._render_query(execution_context)

self._set_query_limit_if_required(execution_context, query, rendered_query)
self._set_query_limit_if_required(execution_context, rendered_query)

return self._execute_query(
query, execution_context, rendered_query, session, log_params
)
return self._execute_query(
execution_context, rendered_query, session, log_params
)
except Exception as ex:
query.status = QueryStatus.FAILED
session.commit()
raise ex

@classmethod
def _get_the_query_db(
Expand Down Expand Up @@ -2544,7 +2547,7 @@ def _validate_access( # pylint: disable=no-self-use
raise SupersetErrorException(ex.error, status=403) from ex

def _render_query( # pylint: disable=no-self-use
self, query: Query, execution_context: SqlJsonExecutionContext, session: Session
self, execution_context: SqlJsonExecutionContext
) -> str:
def validate(
rendered_query: str, template_processor: BaseTemplateProcessor
Expand All @@ -2554,8 +2557,6 @@ def validate(
ast = template_processor._env.parse(rendered_query)
undefined_parameters = find_undeclared_variables(ast) # type: ignore
if undefined_parameters:
query.status = QueryStatus.FAILED
session.commit()
raise SupersetTemplateParamsErrorException(
message=ngettext(
"The parameter %(parameters)s in your query is undefined.",
Expand All @@ -2572,6 +2573,8 @@ def validate(
},
)

query = execution_context.query

try:
template_processor = get_template_processor(
database=query.database, query=query
Expand All @@ -2581,8 +2584,6 @@ def validate(
)
validate(rendered_query, template_processor)
except TemplateError as ex:
query.status = QueryStatus.FAILED
session.commit()
raise SupersetTemplateParamsErrorException(
message=__(
'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.'
Expand All @@ -2593,13 +2594,10 @@ def validate(
return rendered_query

def _set_query_limit_if_required(
self,
execution_context: SqlJsonExecutionContext,
query: Query,
rendered_query: str,
self, execution_context: SqlJsonExecutionContext, rendered_query: str,
) -> None:
if self._is_required_to_set_limit(execution_context):
self._set_query_limit(rendered_query, query, execution_context)
self._set_query_limit(rendered_query, execution_context)

def _is_required_to_set_limit( # pylint: disable=no-self-use
self, execution_context: SqlJsonExecutionContext
Expand All @@ -2609,60 +2607,48 @@ def _is_required_to_set_limit( # pylint: disable=no-self-use
)

def _set_query_limit( # pylint: disable=no-self-use
self,
rendered_query: str,
query: Query,
execution_context: SqlJsonExecutionContext,
self, rendered_query: str, execution_context: SqlJsonExecutionContext,
) -> None:
db_engine_spec = execution_context.database.db_engine_spec # type: ignore
limits = [
db_engine_spec.get_limit_from_sql(rendered_query),
execution_context.limit,
]
if limits[0] is None or limits[0] > limits[1]: # type: ignore
query.limiting_factor = LimitingFactor.DROPDOWN
execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
elif limits[1] > limits[0]: # type: ignore
query.limiting_factor = LimitingFactor.QUERY
execution_context.query.limiting_factor = LimitingFactor.QUERY
else: # limits[0] == limits[1]
query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
query.limit = min(lim for lim in limits if lim is not None)
execution_context.query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
execution_context.query.limit = min(lim for lim in limits if lim is not None)

def _execute_query( # pylint: disable=too-many-arguments
def _execute_query(
self,
query: Query,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
session: Session,
log_params: Optional[Dict[str, Any]],
) -> Tuple[str, int]:
) -> SqlJsonExecutionStatus:
# Flag for whether or not to expand data
# (feature that will expand Presto row objects and arrays)
expand_data: bool = execution_context.expand_data
# Async request.
if execution_context.is_run_asynchronous():
return (
self._sql_json_async(
query, rendered_query, expand_data, session, log_params
),
202,
return self._sql_json_async(
execution_context, rendered_query, session, log_params
)
# Sync request.
return (
self._sql_json_sync(
query, rendered_query, expand_data, session, log_params
),
200,

return self._sql_json_sync(
execution_context, rendered_query, session, log_params
)

@classmethod
def _sql_json_async( # pylint: disable=too-many-arguments
def _sql_json_async(
cls,
query: Query,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
expand_data: bool,
session: Session,
log_params: Optional[Dict[str, Any]],
) -> str:
) -> SqlJsonExecutionStatus:
"""
Send SQL JSON query to celery workers.
Expand All @@ -2671,6 +2657,7 @@ def _sql_json_async( # pylint: disable=too-many-arguments
:param query: The query (SQLAlchemy) object
:return: A Flask Response
"""
query = execution_context.query
logger.info("Query %i: Running query on a Celery worker", query.id)
# Ignore the celery future object and the request may time out.
query_id = query.id
Expand All @@ -2684,7 +2671,7 @@ def _sql_json_async( # pylint: disable=too-many-arguments
if g.user and hasattr(g.user, "username")
else None,
start_time=now_as_float(),
expand_data=expand_data,
expand_data=execution_context.expand_data,
log_params=log_params,
)

Expand Down Expand Up @@ -2719,17 +2706,16 @@ def _sql_json_async( # pylint: disable=too-many-arguments
QueryDAO.update_saved_query_exec_info(query_id)

session.commit()
return cls._convert_query_to_payload(query)
return SqlJsonExecutionStatus.QUERY_IS_RUNNING

@classmethod
def _sql_json_sync(
cls,
query: Query,
execution_context: SqlJsonExecutionContext,
rendered_query: str,
expand_data: bool,
_session: Session,
log_params: Optional[Dict[str, Any]],
) -> str:
) -> SqlJsonExecutionStatus:
"""
Execute SQL query (sql json).
Expand All @@ -2738,18 +2724,22 @@ def _sql_json_sync(
:return: A Flask Response
:raises: SupersetTimeoutException
"""
query = execution_context.query
try:
timeout = config["SQLLAB_TIMEOUT"]
timeout_msg = f"The query exceeded the {timeout} seconds timeout."
query_id = query.id
data = cls._get_sql_results_with_timeout(
query, timeout, rendered_query, expand_data, timeout_msg, log_params
query,
timeout,
rendered_query,
execution_context.expand_data,
timeout_msg,
log_params,
)
# Update saved query if needed
QueryDAO.update_saved_query_exec_info(query_id)

# TODO: set LimitingFactor to display?
payload = cls._convert_sql_result_to_payload(data)
execution_context.set_execution_result(data)
except SupersetTimeoutException as ex:
# re-raise exception for api exception handler
raise ex
Expand All @@ -2767,8 +2757,7 @@ def _sql_json_sync(
)
# old string-only error message
raise SupersetGenericDBErrorException(data["error"])

return payload
return SqlJsonExecutionStatus.HAS_RESULTS

@classmethod
def _get_sql_results_with_timeout( # pylint: disable=too-many-arguments
Expand All @@ -2779,7 +2768,7 @@ def _get_sql_results_with_timeout( # pylint: disable=too-many-arguments
expand_data: bool,
timeout_msg: str,
log_params: Optional[Dict[str, Any]],
) -> SqlResults:
) -> Optional[SqlResults]:
with utils.timeout(seconds=timeout, error_message=timeout_msg):
# pylint: disable=no-value-for-parameter
return sql_lab.get_sql_results(
Expand All @@ -2800,14 +2789,38 @@ def _is_store_results(cls, query: Query) -> bool:
is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta
)

@classmethod
def _convert_sql_result_to_payload(cls, sql_results: SqlResults) -> str:
return json.dumps(
apply_display_max_row_limit(sql_results), # type: ignore
default=utils.pessimistic_json_iso_dttm_ser,
ignore_nan=True,
encoding=None,
)
def _create_response_from_execution_context(
    # pylint: disable=invalid-name, no-self-use
    self,
    execution_context: SqlJsonExecutionContext,
    status: SqlJsonExecutionStatus,
) -> FlaskResponse:
    """Turn an execution outcome into the HTTP response for /sql_json.

    HAS_RESULTS          -> 200 with the (row-limited) result set.
    QUERY_ALREADY_CREATED -> 200 with the existing query's metadata.
    QUERY_IS_RUNNING      -> 202 with the async query's metadata.
    """

    def _to_payload_results_based(execution_result: SqlResults) -> str:
        # Truncate to the configured display limit before serializing;
        # the pessimistic serializer stringifies non-JSON-native values.
        display_max_row = config["DISPLAY_MAX_ROW"]
        return json.dumps(
            apply_display_max_row_limit(execution_result, display_max_row),
            default=utils.pessimistic_json_iso_dttm_ser,
            ignore_nan=True,
            encoding=None,
        )

    def _to_payload_query_based(query: Query) -> str:
        # No results yet (async) or query already handled: send metadata only.
        return json.dumps(
            {"query": query.to_dict()},
            default=utils.json_int_dttm_ser,
            ignore_nan=True,
        )

    if status == SqlJsonExecutionStatus.HAS_RESULTS:
        payload = _to_payload_results_based(
            execution_context.get_execution_result() or {}
        )
        status_code = 200
    else:
        payload = _to_payload_query_based(execution_context.query)
        # BUG FIX: the original tested `status.QUERY_IS_RUNNING`, which looks
        # up the enum member through the instance and is therefore always
        # truthy — QUERY_ALREADY_CREATED would also have returned 202.
        # Compare the enum value instead so only a running async query
        # gets 202 Accepted.
        status_code = (
            202 if status == SqlJsonExecutionStatus.QUERY_IS_RUNNING else 200
        )
    return json_success(payload, status_code)

@has_access
@event_logger.log_this
Expand Down Expand Up @@ -3177,3 +3190,10 @@ def schemas_access_for_file_upload(self) -> FlaskResponse:
"Failed to fetch schemas allowed for csv upload in this database! "
"Please contact your Superset Admin!"
)


class SqlJsonExecutionStatus(Enum):
    """Outcome of handling a sql_json request, used to pick the response."""

    QUERY_ALREADY_CREATED = 1  # an equivalent query was found and is reused
    HAS_RESULTS = 2  # synchronous execution finished; results are available
    QUERY_IS_RUNNING = 3  # query dispatched to a Celery worker (async path)
    FAILED = 4  # NOTE(review): no producer visible in this diff — confirm use

0 comments on commit 521b81a

Please sign in to comment.