diff --git a/docs/docs/contributing/testing-locally.mdx b/docs/docs/contributing/testing-locally.mdx
index 8cf1effb5d938..70ff9f475b75c 100644
--- a/docs/docs/contributing/testing-locally.mdx
+++ b/docs/docs/contributing/testing-locally.mdx
@@ -83,9 +83,20 @@ To run a single test file:
 npm run test -- path/to/file.js
 ```
 
-### Integration Testing
+### e2e Integration Testing
 
-We use [Cypress](https://www.cypress.io/) for integration tests. Tests can be run by `tox -e cypress`. To open Cypress and explore tests first setup and run test server:
+We use [Cypress](https://www.cypress.io/) for end-to-end integration
+tests. One easy option to get started quickly is to leverage `tox` to
+run the whole suite in an isolated environment.
+
+```bash
+tox -e cypress
+```
+
+Alternatively, you can go lower level and set things up in your
+development environment by following these steps:
+
+First, set up a Python/Flask backend:
 
 ```bash
 export SUPERSET_CONFIG=tests.integration_tests.superset_test_config
@@ -98,7 +109,7 @@ superset load-examples --load-test-data
 superset run --port 8081
 ```
 
-Run Cypress tests:
+In another terminal, prepare the frontend and run Cypress tests:
 
 ```bash
 cd superset-frontend
diff --git a/pyproject.toml b/pyproject.toml
index da7dbcfed7756..fd7dc98c86c73 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -267,6 +267,7 @@ usedevelop = true
 allowlist_externals =
     npm
     pkill
+    {toxinidir}/superset-frontend/cypress_build.sh
 
 [testenv:cypress]
 setenv =
diff --git a/superset/config.py b/superset/config.py
index 787d52fa1453a..8e0db72b28dce 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -74,6 +74,9 @@
 
 # Realtime stats logger, a StatsD implementation exists
 STATS_LOGGER = DummyStatsLogger()
+
+# By default, events are logged to the metadata database with `DBEventLogger`
+# Note that you can use `StdOutEventLogger` for debugging
 EVENT_LOGGER = DBEventLogger()
 
 SUPERSET_LOG_VIEW = True
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index afd2791c9c601..208487299a969 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -1398,20 +1398,6 @@ def get_fetch_values_predicate(
                 )
             ) from ex
 
-    def mutate_query_from_config(self, sql: str) -> str:
-        """Apply config's SQL_QUERY_MUTATOR
-
-        Typically adds comments to the query with context"""
-        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
-        mutate_after_split = config["MUTATE_AFTER_SPLIT"]
-        if sql_query_mutator and not mutate_after_split:
-            sql = sql_query_mutator(
-                sql,
-                security_manager=security_manager,
-                database=self.database,
-            )
-        return sql
-
     def get_template_processor(self, **kwargs: Any) -> BaseTemplateProcessor:
         return get_template_processor(table=self, database=self.database, **kwargs)
 
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 451e96927d9fc..0a567102383ab 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -59,7 +59,7 @@
 from sqlalchemy.types import TypeEngine
 from sqlparse.tokens import CTE
 
-from superset import security_manager, sql_parse
+from superset import sql_parse
 from superset.constants import TimeGrain as TimeGrainConstants
 from superset.databases.utils import make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
@@ -1682,16 +1682,8 @@ def process_statement(cls, statement: str, database: Database) -> str:
         """
         parsed_query = ParsedQuery(statement, engine=cls.engine)
         sql = parsed_query.stripped()
-        sql_query_mutator = current_app.config["SQL_QUERY_MUTATOR"]
-        mutate_after_split = current_app.config["MUTATE_AFTER_SPLIT"]
-        if sql_query_mutator and not mutate_after_split:
-            sql = sql_query_mutator(
-                sql,
-                security_manager=security_manager,
-                database=database,
-            )
-        return sql
+        return database.mutate_sql_based_on_config(sql, is_split=True)
 
     @classmethod
     def estimate_query_cost(
diff --git a/superset/models/core.py b/superset/models/core.py
index bfd4c39593392..4c514f3086882 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -66,6 +66,7 @@ from superset.extensions import (
     cache_manager,
     encrypted_field_factory,
+    event_logger,
     security_manager,
     ssh_manager_factory,
 )
@@ -564,6 +565,20 @@ def get_default_schema_for_query(self, query: Query) -> str | None:
         """
         return self.db_engine_spec.get_default_schema_for_query(self, query)
 
+    @staticmethod
+    def post_process_df(df: pd.DataFrame) -> pd.DataFrame:
+        def column_needs_conversion(df_series: pd.Series) -> bool:
+            return (
+                not df_series.empty
+                and isinstance(df_series, pd.Series)
+                and isinstance(df_series[0], (list, dict))
+            )
+
+        for col, coltype in df.dtypes.to_dict().items():
+            if coltype == numpy.object_ and column_needs_conversion(df[col]):
+                df[col] = df[col].apply(utils.json_dumps_w_dates)
+        return df
+
     @property
     def quote_identifier(self) -> Callable[[str], str]:
         """Add quotes to potential identifier expressions if needed"""
@@ -572,7 +587,27 @@ def quote_identifier(self) -> Callable[[str], str]:
     def get_reserved_words(self) -> set[str]:
         return self.get_dialect().preparer.reserved_words
 
-    def get_df(  # pylint: disable=too-many-locals
+    def mutate_sql_based_on_config(self, sql_: str, is_split: bool = False) -> str:
+        """
+        Mutates the SQL query based on the app configuration.
+
+        Two config params here affect the behavior of the SQL query mutator:
+        - `SQL_QUERY_MUTATOR`: A user-provided function that mutates the SQL query.
+        - `MUTATE_AFTER_SPLIT`: If True, the SQL query mutator is only called after
+          the SQL is broken down into smaller queries. If False, the SQL query
+          mutator is applied to the group of queries as a whole. The caller passes
+          in whether the SQL has already been split.
+        """
+        sql_mutator = config["SQL_QUERY_MUTATOR"]
+        if sql_mutator and (is_split == config["MUTATE_AFTER_SPLIT"]):
+            return sql_mutator(
+                sql_,
+                security_manager=security_manager,
+                database=self,
+            )
+        return sql_
+
+    def get_df(
         self,
         sql: str,
         schema: str | None = None,
@@ -581,15 +616,6 @@ def get_df(  # pylint: disable=too-many-locals
         sqls = self.db_engine_spec.parse_sql(sql)
         with self.get_sqla_engine(schema) as engine:
             engine_url = engine.url
-        mutate_after_split = config["MUTATE_AFTER_SPLIT"]
-        sql_query_mutator = config["SQL_QUERY_MUTATOR"]
-
-        def needs_conversion(df_series: pd.Series) -> bool:
-            return (
-                not df_series.empty
-                and isinstance(df_series, pd.Series)
-                and isinstance(df_series[0], (list, dict))
-            )
 
         def _log_query(sql: str) -> None:
             if log_query:
@@ -603,42 +629,30 @@ def _log_query(sql: str) -> None:
 
         with self.get_raw_connection(schema=schema) as conn:
             cursor = conn.cursor()
-            for sql_ in sqls[:-1]:
-                if mutate_after_split:
-                    sql_ = sql_query_mutator(
-                        sql_,
-                        security_manager=security_manager,
-                        database=None,
-                    )
+            df = None
+            for i, sql_ in enumerate(sqls):
+                sql_ = self.mutate_sql_based_on_config(sql_, is_split=True)
                 _log_query(sql_)
-                self.db_engine_spec.execute(cursor, sql_, self)
-                cursor.fetchall()
-
-            if mutate_after_split:
-                last_sql = sql_query_mutator(
-                    sqls[-1],
-                    security_manager=security_manager,
-                    database=None,
-                )
-                _log_query(last_sql)
-                self.db_engine_spec.execute(cursor, last_sql, self)
-            else:
-                _log_query(sqls[-1])
-                self.db_engine_spec.execute(cursor, sqls[-1], self)
-
-            data = self.db_engine_spec.fetch_data(cursor)
-            result_set = SupersetResultSet(
-                data, cursor.description, self.db_engine_spec
-            )
-            df = result_set.to_pandas_df()
+                with event_logger.log_context(
+                    action="execute_sql",
+                    database=self,
+                    object_ref=__name__,
+                ):
+                    self.db_engine_spec.execute(cursor, sql_, self)
+                if i < len(sqls) - 1:
+                    # If it's not the last, we don't keep the results
+                    cursor.fetchall()
+                else:
+                    # Last query, fetch and process the results
+                    data = self.db_engine_spec.fetch_data(cursor)
+                    result_set = SupersetResultSet(
+                        data, cursor.description, self.db_engine_spec
+                    )
+                    df = result_set.to_pandas_df()
 
             if mutator:
                 df = mutator(df)
 
-            for col, coltype in df.dtypes.to_dict().items():
-                if coltype == numpy.object_ and needs_conversion(df[col]):
-                    df[col] = df[col].apply(utils.json_dumps_w_dates)
-
-            return df
+        return self.post_process_df(df)
 
     def compile_sqla_query(self, qry: Select, schema: str | None = None) -> str:
         with self.get_sqla_engine(schema) as engine:
diff --git a/superset/models/helpers.py b/superset/models/helpers.py
index ad90e664ba9ad..7d82b6d08203e 100644
--- a/superset/models/helpers.py
+++ b/superset/models/helpers.py
@@ -880,18 +880,6 @@ def make_sqla_column_compatible(
             sqla_col.key = label_expected
         return sqla_col
 
-    def mutate_query_from_config(self, sql: str) -> str:
-        """Apply config's SQL_QUERY_MUTATOR
-
-        Typically adds comments to the query with context"""
-        if sql_query_mutator := config["SQL_QUERY_MUTATOR"]:
-            sql = sql_query_mutator(
-                sql,
-                security_manager=security_manager,
-                database=self.database,
-            )
-        return sql
-
     @staticmethod
     def _apply_cte(sql: str, cte: Optional[str]) -> str:
         """
@@ -919,7 +907,7 @@ def get_query_str_extended(
             logger.warning("Unable to parse SQL to format it, passing it as-is")
 
         if mutate:
-            sql = self.mutate_query_from_config(sql)
+            sql = self.database.mutate_sql_based_on_config(sql)
 
         return QueryStringExtended(
             applied_template_filters=sqlaq.applied_template_filters,
             applied_filter_columns=sqlaq.applied_filter_columns,
@@ -1393,7 +1381,7 @@ def values_for_column(
         with self.database.get_sqla_engine() as engine:
             sql = qry.compile(engine, compile_kwargs={"literal_binds": True})
             sql = self._apply_cte(sql, cte)
-            sql = self.mutate_query_from_config(sql)
+            sql = self.database.mutate_sql_based_on_config(sql)
 
             df = pd.read_sql_query(sql=sql, con=engine)
             # replace NaN with None to ensure it can be serialized to JSON
diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index e87ae9c5b7264..9076136c64f81 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -46,7 +46,7 @@
     SupersetErrorException,
     SupersetErrorsException,
 )
-from superset.extensions import celery_app
+from superset.extensions import celery_app, event_logger
 from superset.models.core import Database
 from superset.models.sql_lab import Query
 from superset.result_set import SupersetResultSet
@@ -73,7 +73,6 @@
 SQLLAB_HARD_TIMEOUT = SQLLAB_TIMEOUT + 60
 SQL_MAX_ROW = config["SQL_MAX_ROW"]
 SQLLAB_CTAS_NO_LIMIT = config["SQLLAB_CTAS_NO_LIMIT"]
-SQL_QUERY_MUTATOR = config["SQL_QUERY_MUTATOR"]
 log_query = config["QUERY_LOGGER"]
 logger = logging.getLogger(__name__)
 
@@ -264,11 +263,7 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
         sql = apply_limit_if_exists(database, increased_limit, query, sql)
 
     # Hook to allow environment-specific mutation (usually comments) to the SQL
-    sql = SQL_QUERY_MUTATOR(
-        sql,
-        security_manager=security_manager,
-        database=database,
-    )
+    sql = database.mutate_sql_based_on_config(sql)
     try:
         query.executed_sql = sql
         if log_query:
@@ -281,21 +276,26 @@ def execute_sql_statement(  # pylint: disable=too-many-statements
                 log_params,
             )
         db.session.commit()
-        with stats_timing("sqllab.query.time_executing_query", stats_logger):
-            db_engine_spec.execute_with_cursor(cursor, sql, query)
-
-        with stats_timing("sqllab.query.time_fetching_results", stats_logger):
-            logger.debug(
-                "Query %d: Fetching data for query object: %s",
-                query.id,
-                str(query.to_dict()),
-            )
-            data = db_engine_spec.fetch_data(cursor, increased_limit)
-            if query.limit is None or len(data) <= query.limit:
-                query.limiting_factor = LimitingFactor.NOT_LIMITED
-            else:
-                # return 1 row less than increased_query
-                data = data[:-1]
+        with event_logger.log_context(
+            action="execute_sql",
+            database=database,
+            object_ref=__name__,
+        ):
+            with stats_timing("sqllab.query.time_executing_query", stats_logger):
+                db_engine_spec.execute_with_cursor(cursor, sql, query)
+
+            with stats_timing("sqllab.query.time_fetching_results", stats_logger):
+                logger.debug(
+                    "Query %d: Fetching data for query object: %s",
+                    query.id,
+                    str(query.to_dict()),
+                )
+                data = db_engine_spec.fetch_data(cursor, increased_limit)
+                if query.limit is None or len(data) <= query.limit:
+                    query.limiting_factor = LimitingFactor.NOT_LIMITED
+                else:
+                    # return 1 row less than increased_query
+                    data = data[:-1]
     except SoftTimeLimitExceeded as ex:
         query.status = QueryStatus.TIMED_OUT
diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py
index 8e7d8c7209e0e..4d4d898034ca2 100644
--- a/superset/sql_validators/presto_db.py
+++ b/superset/sql_validators/presto_db.py
@@ -20,7 +20,7 @@
 from contextlib import closing
 from typing import Any, Optional
 
-from superset import app, security_manager
+from superset import app
 from superset.models.core import Database
 from superset.sql_parse import ParsedQuery
 from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation
@@ -54,12 +54,7 @@ def validate_statement(
         sql = parsed_query.stripped()
 
         # Hook to allow environment-specific mutation (usually comments) to the SQL
-        if sql_query_mutator := config["SQL_QUERY_MUTATOR"]:
-            sql = sql_query_mutator(
-                sql,
-                security_manager=security_manager,
-                database=database,
-            )
+        sql = database.mutate_sql_based_on_config(sql)
 
         # Transform the final statement to an explain call before sending it on
         # to presto to validate
diff --git a/superset/utils/core.py b/superset/utils/core.py
index 988baed0af8d6..b89fd759a5811 100644
--- a/superset/utils/core.py
+++ b/superset/utils/core.py
@@ -1907,3 +1907,10 @@ def remove_extra_adhoc_filters(form_data: dict[str, Any]) -> None:
             form_data[key] = [
                 filter_ for filter_ in value or [] if not filter_.get("isExtra")
             ]
+
+
+def to_int(v: Any, value_if_invalid: int = 0) -> int:
+    try:
+        return int(v)
+    except (ValueError, TypeError):
+        return value_if_invalid
diff --git a/superset/utils/log.py b/superset/utils/log.py
index 1de599bf08233..0958727f983b2 100644
--- a/superset/utils/log.py
+++ b/superset/utils/log.py
@@ -32,7 +32,7 @@
 from sqlalchemy.exc import SQLAlchemyError
 
 from superset.extensions import stats_logger_manager
-from superset.utils.core import get_user_id, LoggerLevel
+from superset.utils.core import get_user_id, LoggerLevel, to_int
 
 if TYPE_CHECKING:
     from superset.stats_logger import BaseStatsLogger
@@ -52,6 +52,10 @@ def collect_request_payload() -> dict[str, Any]:
         **request.args.to_dict(),
     }
 
+    if request.is_json:
+        json_payload = request.get_json(cache=True, silent=True) or {}
+        payload.update(json_payload)
+
     # save URL match pattern in addition to the request path
     url_rule = str(request.url_rule)
     if url_rule != request.path:
@@ -130,12 +134,13 @@ def log(  # pylint: disable=too-many-arguments
     ) -> None:
         pass
 
-    def log_with_context(  # pylint: disable=too-many-locals
+    def log_with_context(  # pylint: disable=too-many-locals,too-many-arguments
        self,
         action: str,
         duration: timedelta | None = None,
         object_ref: str | None = None,
         log_to_statsd: bool = True,
+        database: Any | None = None,
         **payload_override: dict[str, Any] | None,
     ) -> None:
         # pylint: disable=import-outside-toplevel
@@ -165,11 +170,15 @@ def log_with_context(  # pylint: disable=too-many-locals
         if payload_override:
             payload.update(payload_override)
 
-        dashboard_id: int | None = None
-        try:
-            dashboard_id = int(payload.get("dashboard_id"))  # type: ignore
-        except (TypeError, ValueError):
-            dashboard_id = None
+        dashboard_id = to_int(payload.get("dashboard_id"))
+
+        database_params = {"database_id": payload.get("database_id")}
+        if database and type(database).__name__ == "Database":
+            database_params = {
+                "database_id": database.id,
+                "engine": database.backend,
+                "database_driver": database.driver,
+            }
 
         if "form_data" in payload:
             form_data, _ = get_form_data()
@@ -178,10 +187,7 @@ def log_with_context(  # pylint: disable=too-many-locals
         else:
             slice_id = payload.get("slice_id")
 
-        try:
-            slice_id = int(slice_id)  # type: ignore
-        except (TypeError, ValueError):
-            slice_id = 0
+        slice_id = to_int(slice_id)
 
         if log_to_statsd:
             stats_logger_manager.instance.incr(action)
@@ -201,6 +207,7 @@ def log_with_context(  # pylint: disable=too-many-locals
             slice_id=slice_id,
             duration_ms=duration_ms,
             referrer=referrer,
+            **database_params,
         )
 
     @contextmanager
@@ -209,6 +216,7 @@ def log_context(
         action: str,
         object_ref: str | None = None,
         log_to_statsd: bool = True,
+        **kwargs: Any,
     ) -> Iterator[Callable[..., None]]:
         """
         Log an event with additional information from the request context.
@@ -216,7 +224,7 @@ def log_context(
         :param object_ref: reference to the Python object that triggered this action
         :param log_to_statsd: whether to update statsd counter for the action
         """
-        payload_override = {}
+        payload_override = kwargs.copy()
         start = datetime.now()
         # yield a helper to add additional payload
         yield lambda **kwargs: payload_override.update(kwargs)
@@ -359,3 +367,29 @@ def log(  # pylint: disable=too-many-arguments,too-many-locals
         except SQLAlchemyError as ex:
             logging.error("DBEventLogger failed to log event(s)")
             logging.exception(ex)
+
+
+class StdOutEventLogger(AbstractEventLogger):
+    """Event logger that prints to stdout for debugging purposes"""
+
+    def log(  # pylint: disable=too-many-arguments
+        self,
+        user_id: int | None,
+        action: str,
+        dashboard_id: int | None,
+        duration_ms: int | None,
+        slice_id: int | None,
+        referrer: str | None,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        data = dict(  # pylint: disable=use-dict-literal
+            user_id=user_id,
+            action=action,
+            dashboard_id=dashboard_id,
+            duration_ms=duration_ms,
+            slice_id=slice_id,
+            referrer=referrer,
+            **kwargs,
+        )
+        print("StdOutEventLogger: ", data)
diff --git a/superset/views/utils.py b/superset/views/utils.py
index bf1099a6aa42a..86030b98099fd 100644
--- a/superset/views/utils.py
+++ b/superset/views/utils.py
@@ -151,10 +151,12 @@ def get_form_data(
     form_data: dict[str, Any] = initial_form_data or {}
 
     if has_request_context():
+        json_data = request.get_json(cache=True) if request.is_json else {}
+
         # chart data API requests are JSON
-        request_json_data = (
-            request.json["queries"][0]
-            if request.is_json and "queries" in request.json
+        first_query = (
+            json_data["queries"][0]
+            if "queries" in json_data and json_data["queries"]
             else None
         )
 
@@ -162,8 +164,8 @@ def get_form_data(
         request_form_data = request.form.get("form_data")
         request_args_data = request.args.get("form_data")
 
-        if request_json_data:
-            form_data.update(request_json_data)
+        if first_query:
+            form_data.update(first_query)
         if request_form_data:
             parsed_form_data = loads_request_json(request_form_data)
             # some chart data api requests are form_data
diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py
index 3b20f6a91887a..98c3ea922c34a 100644
--- a/tests/integration_tests/event_logger_tests.py
+++ b/tests/integration_tests/event_logger_tests.py
@@ -144,8 +144,9 @@ def log(
         assert logger.records == [
             {
                 "records": [{"path": "/", "engine": "bar"}],
+                "database_id": None,
                 "user_id": 2,
-                "duration": 15000.0,
+                "duration": 15000,
             }
         ]
 
@@ -191,6 +192,7 @@ def log(
                     "payload_override": {"engine": "sqlite"},
                 }
             ],
+            "database_id": None,
             "user_id": 2,
             "duration": 5558756000,
         }
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 3e5a80881549a..3b2e7690e141a 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -38,6 +38,7 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None:
     database = query.database
     database.allow_dml = False
     database.apply_limit_to_sql.return_value = "SELECT 42 AS answer LIMIT 2"
+    database.mutate_sql_based_on_config.return_value = "SELECT 42 AS answer LIMIT 2"
     db_engine_spec = database.db_engine_spec
     db_engine_spec.is_select_query.return_value = True
    db_engine_spec.fetch_data.return_value = [(42,)]
@@ -71,15 +72,16 @@ def test_execute_sql_statement_with_rls(
     from superset.sql_lab import execute_sql_statement
 
     sql_statement = "SELECT * FROM sales"
+    sql_statement_with_rls = f"{sql_statement} WHERE organization_id=42"
+    sql_statement_with_rls_and_limit = f"{sql_statement_with_rls} LIMIT 101"
     query = mocker.MagicMock()
     query.limit = 100
     query.select_as_cta_used = False
     database = query.database
     database.allow_dml = False
-    database.apply_limit_to_sql.return_value = (
-        "SELECT * FROM sales WHERE organization_id=42 LIMIT 101"
-    )
+    database.apply_limit_to_sql.return_value = sql_statement_with_rls_and_limit
+    database.mutate_sql_based_on_config.return_value = sql_statement_with_rls_and_limit
     db_engine_spec = database.db_engine_spec
     db_engine_spec.is_select_query.return_value = True
     db_engine_spec.fetch_data.return_value = [(42,)]