diff --git a/superset/utils/sqllab_execution_context.py b/superset/utils/sqllab_execution_context.py
index 069449881ee98..6c5532b348b3e 100644
--- a/superset/utils/sqllab_execution_context.py
+++ b/superset/utils/sqllab_execution_context.py
@@ -56,6 +56,8 @@ class SqlJsonExecutionContext:  # pylint: disable=too-many-instance-attributes
     expand_data: bool
     create_table_as_select: Optional[CreateTableAsSelect]
     database: Optional[Database]
+    query: Query
+    _sql_result: Optional[SqlResults] = None
 
     def __init__(self, query_params: Dict[str, Any]):
         self.create_table_as_select = None
@@ -64,6 +66,9 @@ def __init__(self, query_params: Dict[str, Any]):
         self.user_id = self._get_user_id()
         self.client_id_or_short_id = cast(str, self.client_id or utils.shortid()[:10])
 
+    def set_query(self, query: Query) -> None:
+        self.query = query
+
     def _init_from_query_params(self, query_params: Dict[str, Any]) -> None:
         self.database_id = cast(int, query_params.get("database_id"))
         self.schema = cast(str, query_params.get("schema"))
@@ -134,6 +139,12 @@ def _validate_db(self, database: Database) -> None:
         # TODO validate db.id is equal to self.database_id
         pass
 
+    def get_execution_result(self) -> Optional[SqlResults]:
+        return self._sql_result
+
+    def set_execution_result(self, sql_result: Optional[SqlResults]) -> None:
+        self._sql_result = sql_result
+
     def create_query(self) -> Query:
         # pylint: disable=line-too-long
         start_time = now_as_float()
diff --git a/superset/views/core.py b/superset/views/core.py
index a1bb3e9dc144f..34c1f6a829f49 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -22,7 +22,8 @@
 import re
 from contextlib import closing
 from datetime import datetime, timedelta
-from typing import Any, Callable, cast, Dict, List, Optional, Tuple, Union
+from enum import Enum
+from typing import Any, Callable, cast, Dict, List, Optional, Union
 from urllib import parse
 
 import backoff
@@ -176,7 +177,7 @@
     "your query again."
 )
 
-SqlResults = Optional[Dict[str, Any]]
+SqlResults = Dict[str, Any]
 
 
 class Superset(BaseSupersetView):  # pylint: disable=too-many-public-methods
@@ -2423,13 +2424,16 @@ def sql_json(self) -> FlaskResponse:
             "user_agent": cast(Optional[str], request.headers.get("USER_AGENT"))
         }
         execution_context = SqlJsonExecutionContext(request.json)
-        return json_success(*self.sql_json_exec(execution_context, log_params))
+        status: SqlJsonExecutionStatus = self.sql_json_exec(
+            execution_context, log_params
+        )
+        return self._create_response_from_execution_context(execution_context, status)
 
     def sql_json_exec(  # pylint: disable=too-many-statements,useless-suppression
         self,
         execution_context: SqlJsonExecutionContext,
         log_params: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, int]:
+    ) -> SqlJsonExecutionStatus:
         """Runs arbitrary sql and returns data as json"""
         session = db.session()
 
@@ -2437,7 +2441,8 @@ def sql_json_exec(  # pylint: disable=too-many-statements,useless-suppression
         query = self._get_existing_query(execution_context, session)
 
         if self.is_query_handled(query):
-            return self._convert_query_to_payload(cast(Query, query)), 200
+            execution_context.set_query(query)  # type: ignore
+            return SqlJsonExecutionStatus.QUERY_ALREADY_CREATED
 
         return self._run_sql_json_exec_from_scratch(
             execution_context, session, log_params
@@ -2466,34 +2471,32 @@ def is_query_handled(cls, query: Optional[Query]) -> bool:
             QueryStatus.TIMED_OUT,
         ]
 
-    @staticmethod
-    def _convert_query_to_payload(query: Query) -> str:
-        return json.dumps(
-            {"query": query.to_dict()},
-            default=utils.json_int_dttm_ser,
-            ignore_nan=True,
-        )
-
     def _run_sql_json_exec_from_scratch(
         self,
         execution_context: SqlJsonExecutionContext,
         session: Session,
         log_params: Optional[Dict[str, Any]] = None,
-    ) -> Tuple[str, int]:
+    ) -> SqlJsonExecutionStatus:
         execution_context.set_database(
             self._get_the_query_db(execution_context, session)
         )
         query = execution_context.create_query()
-        self._save_new_query(query, session)
-        logger.info("Triggering query_id: %i", query.id)
-        self._validate_access(query, session)
-        rendered_query = self._render_query(query, execution_context, session)
+        try:
+            self._save_new_query(query, session)
+            logger.info("Triggering query_id: %i", query.id)
+            self._validate_access(query, session)
+            execution_context.set_query(query)
+            rendered_query = self._render_query(execution_context)
 
-        self._set_query_limit_if_required(execution_context, query, rendered_query)
+            self._set_query_limit_if_required(execution_context, rendered_query)
 
-        return self._execute_query(
-            query, execution_context, rendered_query, session, log_params
-        )
+            return self._execute_query(
+                execution_context, rendered_query, session, log_params
+            )
+        except Exception:
+            query.status = QueryStatus.FAILED
+            session.commit()
+            raise
 
     @classmethod
     def _get_the_query_db(
@@ -2544,7 +2547,7 @@ def _validate_access(  # pylint: disable=no-self-use
             raise SupersetErrorException(ex.error, status=403) from ex
 
     def _render_query(  # pylint: disable=no-self-use
-        self, query: Query, execution_context: SqlJsonExecutionContext, session: Session
+        self, execution_context: SqlJsonExecutionContext
     ) -> str:
         def validate(
             rendered_query: str, template_processor: BaseTemplateProcessor
@@ -2554,8 +2557,6 @@ def validate(
                 ast = template_processor._env.parse(rendered_query)
                 undefined_parameters = find_undeclared_variables(ast)  # type: ignore
                 if undefined_parameters:
-                    query.status = QueryStatus.FAILED
-                    session.commit()
                     raise SupersetTemplateParamsErrorException(
                         message=ngettext(
                             "The parameter %(parameters)s in your query is undefined.",
@@ -2572,6 +2573,8 @@ def validate(
                         },
                     )
 
+        query = execution_context.query
+
         try:
             template_processor = get_template_processor(
                 database=query.database, query=query
@@ -2581,8 +2584,6 @@ def validate(
             )
             validate(rendered_query, template_processor)
         except TemplateError as ex:
-            query.status = QueryStatus.FAILED
-            session.commit()
             raise SupersetTemplateParamsErrorException(
                 message=__(
                     'The query contains one or more malformed template parameters. Please check your query and confirm that all template parameters are surround by double braces, for example, "{{ ds }}". Then, try running your query again.'
@@ -2593,13 +2594,10 @@ def validate(
         return rendered_query
 
     def _set_query_limit_if_required(
-        self,
-        execution_context: SqlJsonExecutionContext,
-        query: Query,
-        rendered_query: str,
+        self, execution_context: SqlJsonExecutionContext, rendered_query: str,
     ) -> None:
         if self._is_required_to_set_limit(execution_context):
-            self._set_query_limit(rendered_query, query, execution_context)
+            self._set_query_limit(rendered_query, execution_context)
 
     def _is_required_to_set_limit(  # pylint: disable=no-self-use
         self, execution_context: SqlJsonExecutionContext
@@ -2609,10 +2607,7 @@ def _is_required_to_set_limit(  # pylint: disable=no-self-use
         )
 
     def _set_query_limit(  # pylint: disable=no-self-use
-        self,
-        rendered_query: str,
-        query: Query,
-        execution_context: SqlJsonExecutionContext,
+        self, rendered_query: str, execution_context: SqlJsonExecutionContext,
     ) -> None:
         db_engine_spec = execution_context.database.db_engine_spec  # type: ignore
         limits = [
@@ -2620,49 +2615,40 @@ def _set_query_limit(  # pylint: disable=no-self-use
             execution_context.limit,
         ]
         if limits[0] is None or limits[0] > limits[1]:  # type: ignore
-            query.limiting_factor = LimitingFactor.DROPDOWN
+            execution_context.query.limiting_factor = LimitingFactor.DROPDOWN
         elif limits[1] > limits[0]:  # type: ignore
-            query.limiting_factor = LimitingFactor.QUERY
+            execution_context.query.limiting_factor = LimitingFactor.QUERY
         else:  # limits[0] == limits[1]
-            query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
-        query.limit = min(lim for lim in limits if lim is not None)
+            execution_context.query.limiting_factor = LimitingFactor.QUERY_AND_DROPDOWN
+        execution_context.query.limit = min(lim for lim in limits if lim is not None)
 
-    def _execute_query(  # pylint: disable=too-many-arguments
+    def _execute_query(
         self,
-        query: Query,
         execution_context: SqlJsonExecutionContext,
         rendered_query: str,
         session: Session,
         log_params: Optional[Dict[str, Any]],
-    ) -> Tuple[str, int]:
+    ) -> SqlJsonExecutionStatus:
         # Flag for whether or not to expand data
         # (feature that will expand Presto row objects and arrays)
-        expand_data: bool = execution_context.expand_data
         # Async request.
         if execution_context.is_run_asynchronous():
-            return (
-                self._sql_json_async(
-                    query, rendered_query, expand_data, session, log_params
-                ),
-                202,
+            return self._sql_json_async(
+                execution_context, rendered_query, session, log_params
             )
-        # Sync request.
-        return (
-            self._sql_json_sync(
-                query, rendered_query, expand_data, session, log_params
-            ),
-            200,
+
+        return self._sql_json_sync(
+            execution_context, rendered_query, session, log_params
         )
 
     @classmethod
-    def _sql_json_async(  # pylint: disable=too-many-arguments
+    def _sql_json_async(
         cls,
-        query: Query,
+        execution_context: SqlJsonExecutionContext,
         rendered_query: str,
-        expand_data: bool,
         session: Session,
         log_params: Optional[Dict[str, Any]],
-    ) -> str:
+    ) -> SqlJsonExecutionStatus:
         """
         Send SQL JSON query to celery workers.
 
@@ -2671,6 +2657,7 @@ def _sql_json_async(  # pylint: disable=too-many-arguments
         :param query: The query (SQLAlchemy) object
         :return: A Flask Response
         """
+        query = execution_context.query
         logger.info("Query %i: Running query on a Celery worker", query.id)
         # Ignore the celery future object and the request may time out.
         query_id = query.id
@@ -2684,7 +2671,7 @@ def _sql_json_async(  # pylint: disable=too-many-arguments
                 if g.user and hasattr(g.user, "username")
                 else None,
                 start_time=now_as_float(),
-                expand_data=expand_data,
+                expand_data=execution_context.expand_data,
                 log_params=log_params,
             )
 
@@ -2719,17 +2706,16 @@ def _sql_json_async(  # pylint: disable=too-many-arguments
         QueryDAO.update_saved_query_exec_info(query_id)
 
         session.commit()
-        return cls._convert_query_to_payload(query)
+        return SqlJsonExecutionStatus.QUERY_IS_RUNNING
 
     @classmethod
     def _sql_json_sync(
         cls,
-        query: Query,
+        execution_context: SqlJsonExecutionContext,
         rendered_query: str,
-        expand_data: bool,
         _session: Session,
         log_params: Optional[Dict[str, Any]],
-    ) -> str:
+    ) -> SqlJsonExecutionStatus:
         """
         Execute SQL query (sql json).
 
@@ -2738,18 +2724,22 @@ def _sql_json_sync(
         :return: A Flask Response
         :raises: SupersetTimeoutException
         """
+        query = execution_context.query
         try:
             timeout = config["SQLLAB_TIMEOUT"]
             timeout_msg = f"The query exceeded the {timeout} seconds timeout."
             query_id = query.id
             data = cls._get_sql_results_with_timeout(
-                query, timeout, rendered_query, expand_data, timeout_msg, log_params
+                query,
+                timeout,
+                rendered_query,
+                execution_context.expand_data,
+                timeout_msg,
+                log_params,
             )
             # Update saved query if needed
             QueryDAO.update_saved_query_exec_info(query_id)
-
-            # TODO: set LimitingFactor to display?
-            payload = cls._convert_sql_result_to_payload(data)
+            execution_context.set_execution_result(data)
         except SupersetTimeoutException as ex:
             # re-raise exception for api exception handler
             raise ex
@@ -2767,8 +2757,7 @@ def _sql_json_sync(
                 )
             # old string-only error message
             raise SupersetGenericDBErrorException(data["error"])
-
-        return payload
+        return SqlJsonExecutionStatus.HAS_RESULTS
 
     @classmethod
     def _get_sql_results_with_timeout(  # pylint: disable=too-many-arguments
@@ -2779,7 +2768,7 @@ def _get_sql_results_with_timeout(  # pylint: disable=too-many-arguments
         expand_data: bool,
         timeout_msg: str,
         log_params: Optional[Dict[str, Any]],
-    ) -> SqlResults:
+    ) -> Optional[SqlResults]:
         with utils.timeout(seconds=timeout, error_message=timeout_msg):
             # pylint: disable=no-value-for-parameter
             return sql_lab.get_sql_results(
@@ -2800,14 +2789,45 @@ def _is_store_results(cls, query: Query) -> bool:
             is_feature_enabled("SQLLAB_BACKEND_PERSISTENCE") and not query.select_as_cta
         )
 
-    @classmethod
-    def _convert_sql_result_to_payload(cls, sql_results: SqlResults) -> str:
-        return json.dumps(
-            apply_display_max_row_limit(sql_results),  # type: ignore
-            default=utils.pessimistic_json_iso_dttm_ser,
-            ignore_nan=True,
-            encoding=None,
-        )
+    def _create_response_from_execution_context(
+        # pylint: disable=invalid-name, no-self-use
+        self,
+        execution_context: SqlJsonExecutionContext,
+        status: SqlJsonExecutionStatus,
+    ) -> FlaskResponse:
+        """
+        Build the HTTP response for an executed SQL Lab request.
+
+        :param execution_context: Context holding the query and its results
+        :param status: Outcome reported by ``sql_json_exec``
+        :return: JSON response; 202 only while the query is still running
+        """
+        def _to_payload_results_based(execution_result: SqlResults) -> str:
+            display_max_row = config["DISPLAY_MAX_ROW"]
+            return json.dumps(
+                apply_display_max_row_limit(execution_result, display_max_row),
+                default=utils.pessimistic_json_iso_dttm_ser,
+                ignore_nan=True,
+                encoding=None,
+            )
+
+        def _to_payload_query_based(query: Query) -> str:
+            return json.dumps(
+                {"query": query.to_dict()},
+                default=utils.json_int_dttm_ser,
+                ignore_nan=True,
+            )
+
+        status_code = 200
+        if status == SqlJsonExecutionStatus.HAS_RESULTS:
+            payload = _to_payload_results_based(
+                execution_context.get_execution_result() or {}
+            )
+        else:
+            payload = _to_payload_query_based(execution_context.query)
+            if status == SqlJsonExecutionStatus.QUERY_IS_RUNNING:
+                status_code = 202
+        return json_success(payload, status_code)
 
     @has_access
     @event_logger.log_this
@@ -3177,3 +3197,12 @@ def schemas_access_for_file_upload(self) -> FlaskResponse:
                 "Failed to fetch schemas allowed for csv upload in this database! "
                 "Please contact your Superset Admin!"
             )
+
+
+class SqlJsonExecutionStatus(Enum):
+    """Outcome of ``Superset.sql_json_exec``, used to shape the HTTP response."""
+
+    QUERY_ALREADY_CREATED = 1
+    HAS_RESULTS = 2
+    QUERY_IS_RUNNING = 3
+    FAILED = 4