diff --git a/superset/cli/update.py b/superset/cli/update.py index d31460c5e6c8e..ae4ad644c9a8c 100755 --- a/superset/cli/update.py +++ b/superset/cli/update.py @@ -31,6 +31,7 @@ import superset.utils.database as database_utils from superset.extensions import db +from superset.utils.core import override_user from superset.utils.encrypt import SecretsMigrator logger = logging.getLogger(__name__) @@ -54,23 +55,34 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None: @click.command() @with_appcontext -def update_datasources_cache() -> None: +@click.option( + "--username", + "-u", + default=None, + help=( + "Specify which user should execute the underlying SQL queries. If undefined " + "defaults to the user registered with the database connection." + ), +) +def update_datasources_cache(username: Optional[str]) -> None: """Refresh sqllab datasources cache""" # pylint: disable=import-outside-toplevel + from superset import security_manager from superset.models.core import Database - for database in db.session.query(Database).all(): - if database.allow_multi_schema_metadata_fetch: - print("Fetching {} datasources ...".format(database.name)) - try: - database.get_all_table_names_in_database( - force=True, cache=True, cache_timeout=24 * 60 * 60 - ) - database.get_all_view_names_in_database( - force=True, cache=True, cache_timeout=24 * 60 * 60 - ) - except Exception as ex: # pylint: disable=broad-except - print("{}".format(str(ex))) + with override_user(security_manager.find_user(username)): + for database in db.session.query(Database).all(): + if database.allow_multi_schema_metadata_fetch: + print("Fetching {} datasources ...".format(database.name)) + try: + database.get_all_table_names_in_database( + force=True, cache=True, cache_timeout=24 * 60 * 60 + ) + database.get_all_view_names_in_database( + force=True, cache=True, cache_timeout=24 * 60 * 60 + ) + except Exception as ex: # pylint: disable=broad-except + print("{}".format(str(ex))) @click.command() diff --git a/superset/config.py b/superset/config.py index b507d67eae476..003f1fad13317 100644 --- a/superset/config.py +++ b/superset/config.py @@ -680,7 +680,7 @@ def _try_json_readsha(filepath: str, length: int) -> Optional[str]: # database, # query, # schema=None, -# user=None, +# user=None, # TODO(john-bodley): Deprecate in 3.0. # client=None, # security_manager=None, # log_params=None, @@ -1020,9 +1020,14 @@ def CSV_TO_HIVE_UPLOAD_DIRECTORY_FUNC( # pylint: disable=invalid-name # The use case is can be around adding some sort of comment header # with information such as the username and worker node information # -# def SQL_QUERY_MUTATOR(sql, user_name=user_name, security_manager=security_manager, database=database): +# def SQL_QUERY_MUTATOR( +# sql, +# user_name=user_name, # TODO(john-bodley): Deprecate in 3.0. +# security_manager=security_manager, +# database=database, +# ): # dttm = datetime.now().isoformat() -# return f"-- [SQL LAB] {username} {dttm}\n{sql}" +# return f"-- [SQL LAB] {user_name} {dttm}\n{sql}" # For backward compatibility, you can unpack any of the above arguments in your # function definition, but keep the **kwargs as the last argument to allow new args # to be added later without any errors. 
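For illustration only (not part of the diff): since the user_name keyword is slated for deprecation, a forward-compatible SQL_QUERY_MUTATOR in superset_config.py can stop relying on it and resolve the user itself via get_username(), which reads the (possibly overridden) flask.g.user. A minimal sketch, assuming the **kwargs convention recommended in the config comment above:

from datetime import datetime
from typing import Any

from superset.utils.core import get_username


def SQL_QUERY_MUTATOR(sql: str, **kwargs: Any) -> str:
    # get_username() reads flask.g.user, which override_user() may have swapped in;
    # it returns None outside a request/override, hence the fallback label below.
    username = get_username() or "anonymous"
    dttm = datetime.now().isoformat()
    return f"-- [SQL LAB] {username} {dttm}\n{sql}"

Accepting **kwargs keeps the mutator working whether callers still pass user_name (as the call sites in this patch do, pending deprecation) or stop passing it entirely.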
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 489c483baf62a..4bf02b25cf894 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -120,6 +120,7 @@ from superset.utils.core import ( GenericDataType, get_column_name, + get_username, is_adhoc_column, MediumText, QueryObjectFilterClause, @@ -917,10 +918,9 @@ def mutate_query_from_config(self, sql: str) -> str: Typically adds comments to the query with context""" sql_query_mutator = config["SQL_QUERY_MUTATOR"] if sql_query_mutator: - username = utils.get_username() sql = sql_query_mutator( sql, - user_name=username, + user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. security_manager=security_manager, database=self.database, ) diff --git a/superset/databases/commands/test_connection.py b/superset/databases/commands/test_connection.py index 6fec767a1f52d..7066974128591 100644 --- a/superset/databases/commands/test_connection.py +++ b/superset/databases/commands/test_connection.py @@ -38,6 +38,7 @@ from superset.exceptions import SupersetSecurityException, SupersetTimeoutException from superset.extensions import event_logger from superset.models.core import Database +from superset.utils.core import override_user logger = logging.getLogger(__name__) @@ -74,42 +75,43 @@ def run(self) -> None: database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - username = self._actor.username if self._actor is not None else None - engine = database.get_sqla_engine(user_name=username) - event_logger.log_with_context( - action="test_connection_attempt", - engine=database.db_engine_spec.__name__, - ) - with closing(engine.raw_connection()) as conn: - try: - alive = func_timeout( - int( - app.config[ - "TEST_DATABASE_CONNECTION_TIMEOUT" - ].total_seconds() - ), - engine.dialect.do_ping, - args=(conn,), - ) - except (sqlite3.ProgrammingError, RuntimeError): - # SQLite can't run on a separate thread, so ``func_timeout`` fails - # RuntimeError catches the equivalent error from duckdb. - alive = engine.dialect.do_ping(conn) - except FunctionTimedOut as ex: - raise SupersetTimeoutException( - error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, - message=( - "Please check your connection details and database settings, " - "and ensure that your database is accepting connections, " - "then try connecting again." - ), - level=ErrorLevel.ERROR, - extra={"sqlalchemy_uri": database.sqlalchemy_uri}, - ) from ex - except Exception: # pylint: disable=broad-except - alive = False - if not alive: - raise DBAPIError(None, None, None) + + with override_user(self._actor): + engine = database.get_sqla_engine() + event_logger.log_with_context( + action="test_connection_attempt", + engine=database.db_engine_spec.__name__, + ) + with closing(engine.raw_connection()) as conn: + try: + alive = func_timeout( + int( + app.config[ + "TEST_DATABASE_CONNECTION_TIMEOUT" + ].total_seconds() + ), + engine.dialect.do_ping, + args=(conn,), + ) + except (sqlite3.ProgrammingError, RuntimeError): + # SQLite can't run on a separate thread, so ``func_timeout`` fails + # RuntimeError catches the equivalent error from duckdb. + alive = engine.dialect.do_ping(conn) + except FunctionTimedOut as ex: + raise SupersetTimeoutException( + error_type=SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, + message=( + "Please check your connection details and database " + "settings, and ensure that your database is accepting " + "connections, then try connecting again." 
+ ), + level=ErrorLevel.ERROR, + extra={"sqlalchemy_uri": database.sqlalchemy_uri}, + ) from ex + except Exception: # pylint: disable=broad-except + alive = False + if not alive: + raise DBAPIError(None, None, None) # Log succesful connection test with engine event_logger.log_with_context( diff --git a/superset/databases/commands/validate.py b/superset/databases/commands/validate.py index fa05dc7c3030c..145965fc641fc 100644 --- a/superset/databases/commands/validate.py +++ b/superset/databases/commands/validate.py @@ -35,6 +35,7 @@ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.extensions import event_logger from superset.models.core import Database +from superset.utils.core import override_user BYPASS_VALIDATION_ENGINES = {"bigquery"} @@ -115,22 +116,23 @@ def run(self) -> None: ) database.set_sqlalchemy_uri(sqlalchemy_uri) database.db_engine_spec.mutate_db_for_connection_test(database) - username = self._actor.username if self._actor is not None else None - engine = database.get_sqla_engine(user_name=username) - try: - with closing(engine.raw_connection()) as conn: - alive = engine.dialect.do_ping(conn) - except Exception as ex: - url = make_url_safe(sqlalchemy_uri) - context = { - "hostname": url.host, - "password": url.password, - "port": url.port, - "username": url.username, - "database": url.database, - } - errors = database.db_engine_spec.extract_errors(ex, context) - raise DatabaseTestConnectionFailedError(errors) from ex + + with override_user(self._actor): + engine = database.get_sqla_engine() + try: + with closing(engine.raw_connection()) as conn: + alive = engine.dialect.do_ping(conn) + except Exception as ex: + url = make_url_safe(sqlalchemy_uri) + context = { + "hostname": url.host, + "password": url.password, + "port": url.port, + "username": url.username, + "database": url.database, + } + errors = database.db_engine_spec.extract_errors(ex, context) + raise DatabaseTestConnectionFailedError(errors) from ex if not alive: raise DatabaseOfflineError( diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index c31b0a7229bd4..6a2ddc5e5c3f4 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -40,7 +40,7 @@ import sqlparse from apispec import APISpec from apispec.ext.marshmallow import MarshmallowPlugin -from flask import current_app, g +from flask import current_app from flask_babel import gettext as __, lazy_gettext as _ from marshmallow import fields, Schema from marshmallow.validate import Range @@ -64,7 +64,7 @@ from superset.sql_parse import ParsedQuery, Table from superset.superset_typing import ResultSetColumnType from superset.utils import core as utils -from superset.utils.core import ColumnSpec, GenericDataType +from superset.utils.core import ColumnSpec, GenericDataType, get_username from superset.utils.hashing import md5_sha_from_str from superset.utils.network import is_hostname_valid, is_port_open @@ -392,10 +392,7 @@ def get_engine( schema: Optional[str] = None, source: Optional[str] = None, ) -> Engine: - user_name = utils.get_username() - return database.get_sqla_engine( - schema=schema, nullpool=True, user_name=user_name, source=source - ) + return database.get_sqla_engine(schema=schema, source=source) @classmethod def get_timestamp_expr( @@ -1158,15 +1155,12 @@ def query_cost_formatter( raise Exception("Database does not support cost estimation") @classmethod - def process_statement( - cls, statement: str, database: "Database", user_name: str - ) -> str: + def 
process_statement(cls, statement: str, database: "Database") -> str: """ Process a SQL statement by stripping and mutating it. :param statement: A single SQL statement :param database: Database instance - :param user_name: Effective username :return: Dictionary with different costs """ parsed_query = ParsedQuery(statement) @@ -1175,7 +1169,7 @@ def process_statement( if sql_query_mutator: sql = sql_query_mutator( sql, - user_name=user_name, + user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. security_manager=security_manager, database=database, ) @@ -1198,7 +1192,6 @@ def estimate_query_cost( if not cls.get_allow_cost_estimate(extra): raise Exception("Database does not support cost estimation") - user_name = g.user.username if g.user and hasattr(g.user, "username") else None parsed_query = sql_parse.ParsedQuery(sql) statements = parsed_query.get_statements() @@ -1207,9 +1200,7 @@ def estimate_query_cost( with closing(engine.raw_connection()) as conn: cursor = conn.cursor() for statement in statements: - processed_statement = cls.process_statement( - statement, database, user_name - ) + processed_statement = cls.process_statement(statement, database) costs.append(cls.estimate_statement_cost(processed_statement, cursor)) return costs diff --git a/superset/models/core.py b/superset/models/core.py index 8d16ea39f8c5f..1385157d8bab9 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -61,6 +61,7 @@ from superset.models.tags import FavStarUpdater from superset.result_set import SupersetResultSet from superset.utils import cache as cache_util, core as utils +from superset.utils.core import get_username from superset.utils.memoized import memoized config = app.config @@ -322,29 +323,21 @@ def set_sqlalchemy_uri(self, uri: str) -> None: conn.password = PASSWORD_MASK if conn.password else None self.sqlalchemy_uri = str(conn) # hides the password - def get_effective_user( - self, - object_url: URL, - user_name: Optional[str] = None, - ) -> Optional[str]: + def get_effective_user(self, object_url: URL) -> Optional[str]: """ Get the effective user, especially during impersonation. + :param object_url: SQL Alchemy URL object - :param user_name: Default username :return: The effective username """ - effective_username = None - if self.impersonate_user: - effective_username = object_url.username - if user_name: - effective_username = user_name - elif ( - hasattr(g, "user") - and hasattr(g.user, "username") - and g.user.username is not None - ): - effective_username = g.user.username - return effective_username + + return ( # pylint: disable=used-before-assignment + username + if (username := get_username()) + else object_url.username + if self.impersonate_user + else None + ) @memoized( watch=( @@ -358,13 +351,12 @@ def get_sqla_engine( self, schema: Optional[str] = None, nullpool: bool = True, - user_name: Optional[str] = None, source: Optional[utils.QuerySource] = None, ) -> Engine: extra = self.get_extra() sqlalchemy_url = make_url_safe(self.sqlalchemy_uri_decrypted) self.db_engine_spec.adjust_database_uri(sqlalchemy_url, schema) - effective_username = self.get_effective_user(sqlalchemy_url, user_name) + effective_username = self.get_effective_user(sqlalchemy_url) # If using MySQL or Presto for example, will set url.username # If using Hive, will not do anything yet since that relies on a # configuration parameter instead. 
@@ -421,12 +413,9 @@ def get_df( # pylint: disable=too-many-locals sql: str, schema: Optional[str] = None, mutator: Optional[Callable[[pd.DataFrame], None]] = None, - username: Optional[str] = None, ) -> pd.DataFrame: sqls = self.db_engine_spec.parse_sql(sql) - - engine = self.get_sqla_engine(schema=schema, user_name=username) - username = utils.get_username() or username + engine = self.get_sqla_engine(schema) def needs_conversion(df_series: pd.Series) -> bool: return ( @@ -437,7 +426,14 @@ def needs_conversion(df_series: pd.Series) -> bool: def _log_query(sql: str) -> None: if log_query: - log_query(engine.url, sql, schema, username, __name__, security_manager) + log_query( + engine.url, + sql, + schema, + get_username(), + __name__, + security_manager, + ) with closing(engine.raw_connection()) as conn: cursor = conn.cursor() diff --git a/superset/reports/commands/alert.py b/superset/reports/commands/alert.py index f5879a037f378..7a18996df3b3c 100644 --- a/superset/reports/commands/alert.py +++ b/superset/reports/commands/alert.py @@ -25,7 +25,7 @@ from celery.exceptions import SoftTimeLimitExceeded from flask_babel import lazy_gettext as _ -from superset import app, jinja_context +from superset import app, jinja_context, security_manager from superset.commands.base import BaseCommand from superset.models.reports import ReportSchedule, ReportScheduleValidatorType from superset.reports.commands.exceptions import ( @@ -36,6 +36,7 @@ AlertQueryTimeout, AlertValidatorConfigError, ) +from superset.utils.core import override_user logger = logging.getLogger(__name__) @@ -145,18 +146,21 @@ def _execute_query(self) -> pd.DataFrame: limited_rendered_sql = self._report_schedule.database.apply_limit_to_sql( rendered_sql, ALERT_SQL_LIMIT ) - query_username = app.config["THUMBNAIL_SELENIUM_USER"] - start = default_timer() - df = self._report_schedule.database.get_df( - sql=limited_rendered_sql, username=query_username - ) - stop = default_timer() - logger.info( - "Query for %s took %.2f ms", - self._report_schedule.name, - (stop - start) * 1000.0, - ) - return df + + with override_user( + security_manager.find_user( + username=app.config["THUMBNAIL_SELENIUM_USER"] + ) + ): + start = default_timer() + df = self._report_schedule.database.get_df(sql=limited_rendered_sql) + stop = default_timer() + logger.info( + "Query for %s took %.2f ms", + self._report_schedule.name, + (stop - start) * 1000.0, + ) + return df except SoftTimeLimitExceeded as ex: logger.warning("A timeout occurred while executing the alert query: %s", ex) raise AlertQueryTimeout() from ex diff --git a/superset/reports/commands/execute.py b/superset/reports/commands/execute.py index c006d007c4c57..9dc3cc1955641 100644 --- a/superset/reports/commands/execute.py +++ b/superset/reports/commands/execute.py @@ -25,7 +25,7 @@ from flask_appbuilder.security.sqla.models import User from sqlalchemy.orm import Session -from superset import app +from superset import app, security_manager from superset.commands.base import BaseCommand from superset.commands.exceptions import CommandException from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType @@ -179,11 +179,10 @@ def _get_url( **kwargs, ) - def _get_user(self) -> User: - user = ( - self._session.query(User) - .filter(User.username == app.config["THUMBNAIL_SELENIUM_USER"]) - .one_or_none() + @staticmethod + def _get_user() -> User: + user = security_manager.find_user( + username=app.config["THUMBNAIL_SELENIUM_USER"] ) if not user: raise 
ReportScheduleSelleniumUserNotFoundError() diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 3f43c0faae66c..cd3684b1d8ea8 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -50,7 +50,13 @@ from superset.sql_parse import CtasMethod, insert_rls, ParsedQuery from superset.sqllab.limiting_factor import LimitingFactor from superset.utils.celery import session_scope -from superset.utils.core import json_iso_dttm_ser, QuerySource, zlib_compress +from superset.utils.core import ( + get_username, + json_iso_dttm_ser, + override_user, + QuerySource, + zlib_compress, +) from superset.utils.dates import now_as_float from superset.utils.decorators import stats_timing @@ -155,37 +161,35 @@ def get_sql_results( # pylint: disable=too-many-arguments rendered_query: str, return_results: bool = True, store_results: bool = False, - user_name: Optional[str] = None, + username: Optional[str] = None, start_time: Optional[float] = None, expand_data: bool = False, log_params: Optional[Dict[str, Any]] = None, ) -> Optional[Dict[str, Any]]: """Executes the sql query returns the results.""" with session_scope(not ctask.request.called_directly) as session: - - try: - return execute_sql_statements( - query_id, - rendered_query, - return_results, - store_results, - user_name, - session=session, - start_time=start_time, - expand_data=expand_data, - log_params=log_params, - ) - except Exception as ex: # pylint: disable=broad-except - logger.debug("Query %d: %s", query_id, ex) - stats_logger.incr("error_sqllab_unhandled") - query = get_query(query_id, session) - return handle_query_error(ex, query, session) + with override_user(security_manager.find_user(username)): + try: + return execute_sql_statements( + query_id, + rendered_query, + return_results, + store_results, + session=session, + start_time=start_time, + expand_data=expand_data, + log_params=log_params, + ) + except Exception as ex: # pylint: disable=broad-except + logger.debug("Query %d: %s", query_id, ex) + stats_logger.incr("error_sqllab_unhandled") + query = get_query(query_id, session) + return handle_query_error(ex, query, session) def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals,too-many-statements sql_statement: str, query: Query, - user_name: Optional[str], session: Session, cursor: Any, log_params: Optional[Dict[str, Any]], @@ -204,7 +208,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals parsed_query._parsed[0], # pylint: disable=protected-access database.id, query.schema, - username=user_name, + username=get_username(), ) ) ) @@ -246,7 +250,10 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals # Hook to allow environment-specific mutation (usually comments) to the SQL sql = SQL_QUERY_MUTATOR( - sql, user_name=user_name, security_manager=security_manager, database=database + sql, + user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. 
+ security_manager=security_manager, + database=database, ) try: query.executed_sql = sql @@ -255,7 +262,7 @@ def execute_sql_statement( # pylint: disable=too-many-arguments,too-many-locals query.database.sqlalchemy_uri, query.executed_sql, query.schema, - user_name, + get_username(), __name__, security_manager, log_params, @@ -375,7 +382,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca rendered_query: str, return_results: bool, store_results: bool, - user_name: Optional[str], session: Session, start_time: Optional[float], expand_data: bool, @@ -452,12 +458,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca ) ) - engine = database.get_sqla_engine( - schema=query.schema, - nullpool=True, - user_name=user_name, - source=QuerySource.SQL_LAB, - ) + engine = database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB) # Sharing a single connection and cursor across the # execution of all statements (if many) with closing(engine.raw_connection()) as conn: @@ -490,7 +491,6 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca result_set = execute_sql_statement( statement, query, - user_name, session, cursor, log_params, @@ -597,7 +597,7 @@ def execute_sql_statements( # pylint: disable=too-many-arguments, too-many-loca return None -def cancel_query(query: Query, user_name: Optional[str] = None) -> bool: +def cancel_query(query: Query) -> bool: """ Cancel a running query. @@ -605,7 +605,6 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool: action is required. :param query: Query to cancel - :param user_name: Default username :return: True if query cancelled successfully, False otherwise """ @@ -616,12 +615,7 @@ def cancel_query(query: Query, user_name: Optional[str] = None) -> bool: if cancel_query_id is None: return False - engine = query.database.get_sqla_engine( - schema=query.schema, - nullpool=True, - user_name=user_name, - source=QuerySource.SQL_LAB, - ) + engine = query.database.get_sqla_engine(query.schema, source=QuerySource.SQL_LAB) with closing(engine.raw_connection()) as conn: with closing(conn.cursor()) as cursor: diff --git a/superset/sql_validators/presto_db.py b/superset/sql_validators/presto_db.py index 6d0311f120efb..70b324c900736 100644 --- a/superset/sql_validators/presto_db.py +++ b/superset/sql_validators/presto_db.py @@ -20,13 +20,11 @@ from contextlib import closing from typing import Any, Dict, List, Optional -from flask import g - from superset import app, security_manager from superset.models.core import Database from superset.sql_parse import ParsedQuery from superset.sql_validators.base import BaseSQLValidator, SQLValidationAnnotation -from superset.utils.core import QuerySource +from superset.utils.core import get_username, QuerySource MAX_ERROR_ROWS = 10 @@ -45,7 +43,10 @@ class PrestoDBSQLValidator(BaseSQLValidator): @classmethod def validate_statement( - cls, statement: str, database: Database, cursor: Any, user_name: str + cls, + statement: str, + database: Database, + cursor: Any, ) -> Optional[SQLValidationAnnotation]: # pylint: disable=too-many-locals db_engine_spec = database.db_engine_spec @@ -57,7 +58,7 @@ def validate_statement( if sql_query_mutator: sql = sql_query_mutator( sql, - user_name=user_name, + user_name=get_username(), # TODO(john-bodley): Deprecate in 3.0. 
security_manager=security_manager, database=database, ) @@ -157,26 +158,18 @@ def validate( For example, "SELECT 1 FROM default.mytable" becomes "EXPLAIN (TYPE VALIDATE) SELECT 1 FROM default.mytable. """ - user_name = g.user.username if g.user and hasattr(g.user, "username") else None parsed_query = ParsedQuery(sql) statements = parsed_query.get_statements() logger.info("Validating %i statement(s)", len(statements)) - engine = database.get_sqla_engine( - schema=schema, - nullpool=True, - user_name=user_name, - source=QuerySource.SQL_LAB, - ) + engine = database.get_sqla_engine(schema, source=QuerySource.SQL_LAB) # Sharing a single connection and cursor across the # execution of all statements (if many) annotations: List[SQLValidationAnnotation] = [] with closing(engine.raw_connection()) as conn: cursor = conn.cursor() for statement in parsed_query.get_statements(): - annotation = cls.validate_statement( - statement, database, cursor, user_name - ) + annotation = cls.validate_statement(statement, database, cursor) if annotation: annotations.append(annotation) logger.debug("Validation found %i error(s)", len(annotations)) diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py index 77023b341531c..3d55047b41042 100644 --- a/superset/sqllab/sql_json_executer.py +++ b/superset/sqllab/sql_json_executer.py @@ -22,7 +22,6 @@ from abc import ABC from typing import Any, Callable, Dict, Optional, TYPE_CHECKING -from flask import g from flask_babel import gettext as __ from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -34,6 +33,7 @@ ) from superset.sqllab.command_status import SqlJsonExecutionStatus from superset.utils import core as utils +from superset.utils.core import get_username from superset.utils.dates import now_as_float if TYPE_CHECKING: @@ -139,9 +139,7 @@ def _get_sql_results( rendered_query, return_results=True, store_results=self._is_store_results(execution_context), - user_name=g.user.username - if g.user and hasattr(g.user, "username") - else None, + username=get_username(), expand_data=execution_context.expand_data, log_params=log_params, ) @@ -174,9 +172,7 @@ def execute( rendered_query, return_results=False, store_results=not execution_context.select_as_cta, - user_name=g.user.username - if g.user and hasattr(g.user, "username") - else None, + username=get_username(), start_time=now_as_float(), expand_data=execution_context.expand_data, log_params=log_params, diff --git a/superset/utils/core.py b/superset/utils/core.py index f0750ffb6fa60..692de16e9182d 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -32,6 +32,7 @@ import traceback import uuid import zlib +from contextlib import contextmanager from datetime import date, datetime, time, timedelta from distutils.util import strtobool from email.mime.application import MIMEApplication @@ -1408,6 +1409,30 @@ def get_username() -> Optional[str]: return None +@contextmanager +def override_user(user: Optional[User]) -> Iterator[Any]: + """ + Temporarily override the current user (if defined) per `flask.g`. + + Sometimes, often in the context of async Celery tasks, it is useful to switch the + current user (which may be undefined) to different one, execute some SQLAlchemy + tasks and then revert back to the original one. 
+ + :param user: The override user + """ + + # pylint: disable=assigning-non-slot + if hasattr(g, "user"): + current = g.user + g.user = user + yield + g.user = current + else: + g.user = user + yield + delattr(g, "user") + + def parse_ssl_cert(certificate: str) -> _Certificate: """ Parses the contents of a certificate and returns a valid certificate object diff --git a/superset/views/core.py b/superset/views/core.py index 660c527c0a882..fa3437e47b9c0 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -1367,11 +1367,7 @@ def testconn(self) -> FlaskResponse: # pylint: disable=no-self-use ) database.set_sqlalchemy_uri(uri) database.db_engine_spec.mutate_db_for_connection_test(database) - - username = ( - g.user.username if g.user and hasattr(g.user, "username") else None - ) - engine = database.get_sqla_engine(user_name=username) + engine = database.get_sqla_engine() with closing(engine.raw_connection()) as conn: if engine.dialect.do_ping(conn): @@ -2298,7 +2294,7 @@ def stop_query(self) -> FlaskResponse: ) return self.json_response("OK") - if not sql_lab.cancel_query(query, g.user.username if g.user else None): + if not sql_lab.cancel_query(query): raise SupersetCancelQueryException("Could not cancel query") query.status = QueryStatus.STOPPED diff --git a/tests/conftest.py b/tests/conftest.py index 92f9b10d955ad..2c129965f1bd6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -28,8 +28,11 @@ from typing import Callable, TYPE_CHECKING from unittest.mock import MagicMock, Mock, PropertyMock +from flask import Flask +from flask.ctx import AppContext from pytest import fixture +from superset.app import create_app from tests.example_data.data_loading.pandas.pandas_data_loader import PandasDataLoader from tests.example_data.data_loading.pandas.pands_data_loading_conf import ( PandasLoaderConfigurations, diff --git a/tests/integration_tests/access_tests.py b/tests/integration_tests/access_tests.py index d26b07504ced4..c2319ff5b52c1 100644 --- a/tests/integration_tests/access_tests.py +++ b/tests/integration_tests/access_tests.py @@ -21,6 +21,8 @@ from unittest import mock import pytest +from flask import g +from flask.ctx import AppContext from sqlalchemy import inspect from tests.integration_tests.fixtures.birth_names_dashboard import ( @@ -41,6 +43,7 @@ from superset.connectors.sqla.models import SqlaTable from superset.models import core as models from superset.models.datasource_access_request import DatasourceAccessRequest +from superset.utils.core import get_username, override_user from superset.utils.database import get_example_database from .base_tests import SupersetTestCase @@ -86,7 +89,7 @@ SCHEMA_ACCESS_ROLE = "schema_access_role" -def create_access_request(session, ds_type, ds_name, role_name, user_name): +def create_access_request(session, ds_type, ds_name, role_name, username): ds_class = ConnectorRegistry.sources[ds_type] # TODO: generalize datasource names if ds_type == "table": @@ -102,7 +105,7 @@ def create_access_request(session, ds_type, ds_name, role_name, user_name): access_request = DatasourceAccessRequest( datasource_id=ds.id, datasource_type=ds_type, - created_by_fk=security_manager.find_user(username=user_name).id, + created_by_fk=security_manager.find_user(username=username).id, ) session.add(access_request) session.commit() @@ -565,5 +568,46 @@ def test_request_access(self): session.commit() +@pytest.mark.parametrize( + "username", + [ + None, + "gamma", + ], +) +def test_get_username(app_context: AppContext, username: str) -> None: + assert not 
hasattr(g, "user") + assert get_username() is None + + g.user = security_manager.find_user(username) + assert get_username() == username + + +@pytest.mark.parametrize( + "username", + [ + None, + "gamma", + ], +) +def test_override_user(app_context: AppContext, username: str) -> None: + admin = security_manager.find_user(username="admin") + user = security_manager.find_user(username) + + assert not hasattr(g, "user") + + with override_user(user): + assert g.user == user + + assert not hasattr(g, "user") + + g.user = admin + + with override_user(user): + assert g.user == user + + assert g.user == admin + + if __name__ == "__main__": unittest.main() diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index c2d2ef990e05e..ebad0dffd63b6 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -329,7 +329,7 @@ def run_sql( self, sql, client_id=None, - user_name=None, + username=None, raise_on_error=False, query_limit=None, database_name="examples", @@ -340,9 +340,9 @@ def run_sql( ctas_method=CtasMethod.TABLE, template_params="{}", ): - if user_name: + if username: self.logout() - self.login(username=(user_name or "admin")) + self.login(username=username) dbid = SupersetTestCase.get_database_by_name(database_name).id json_payload = { "database_id": dbid, @@ -427,14 +427,14 @@ def validate_sql( self, sql, client_id=None, - user_name=None, + username=None, raise_on_error=False, database_name="examples", template_params=None, ): - if user_name: + if username: self.logout() - self.login(username=(user_name if user_name else "admin")) + self.login(username=username) dbid = SupersetTestCase.get_database_by_name(database_name).id resp = self.get_json_resp( "/superset/validate_sql_json/", diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 68cd844d874b9..58943246c545b 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -1064,7 +1064,7 @@ def test_explore_json_dist_bar_order(self): LIMIT 10; """, client_id="client_id_1", - user_name="admin", + username="admin", ) count_ds = [] count_name = [] @@ -1454,7 +1454,7 @@ def test_sqllab_backend_persistence_payload(self): self.run_sql( "SELECT name FROM birth_names", "client_id_1", - user_name=username, + username=username, raise_on_error=True, sql_editor_id=str(tab_state_id), ) @@ -1462,7 +1462,7 @@ def test_sqllab_backend_persistence_payload(self): self.run_sql( "SELECT name FROM birth_names", "client_id_2", - user_name=username, + username=username, raise_on_error=True, ) diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index a69ffc931d41d..ace75da35a88f 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -20,8 +20,10 @@ import unittest from unittest import mock +from superset import security_manager from superset.connectors.sqla.models import SqlaTable from superset.exceptions import SupersetException +from superset.utils.core import override_user from tests.integration_tests.fixtures.birth_names_dashboard import ( load_birth_names_dashboard_with_slices, load_birth_names_data, @@ -112,21 +114,22 @@ def test_database_schema_mysql(self): ) def test_database_impersonate_user(self): uri = "mysql://root@localhost" - example_user = "giuseppe" + example_user = security_manager.find_user(username="gamma") model = Database(database_name="test_database", sqlalchemy_uri=uri) - model.impersonate_user = True - 
user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username - self.assertEqual(example_user, user_name) + with override_user(example_user): + model.impersonate_user = True + username = make_url(model.get_sqla_engine().url).username + self.assertEqual(example_user.username, username) - model.impersonate_user = False - user_name = make_url(model.get_sqla_engine(user_name=example_user).url).username - self.assertNotEqual(example_user, user_name) + model.impersonate_user = False + username = make_url(model.get_sqla_engine().url).username + self.assertNotEqual(example_user.username, username) @mock.patch("superset.models.core.create_engine") def test_impersonate_user_presto(self, mocked_create_engine): uri = "presto://localhost" - principal_user = "logged_in_user" + principal_user = security_manager.find_user(username="gamma") extra = """ { "metadata_params": {}, @@ -142,64 +145,66 @@ def test_impersonate_user_presto(self, mocked_create_engine): } """ - model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra) - - model.impersonate_user = True - model.get_sqla_engine(user_name=principal_user) - call_args = mocked_create_engine.call_args + with override_user(principal_user): + model = Database( + database_name="test_database", sqlalchemy_uri=uri, extra=extra + ) + model.impersonate_user = True + model.get_sqla_engine() + call_args = mocked_create_engine.call_args - assert str(call_args[0][0]) == "presto://logged_in_user@localhost" + assert str(call_args[0][0]) == "presto://gamma@localhost" - assert call_args[1]["connect_args"] == { - "protocol": "https", - "username": "original_user", - "password": "original_user_password", - "principal_username": "logged_in_user", - } + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + "principal_username": "gamma", + } - model.impersonate_user = False - model.get_sqla_engine(user_name=principal_user) - call_args = mocked_create_engine.call_args + model.impersonate_user = False + model.get_sqla_engine() + call_args = mocked_create_engine.call_args - assert str(call_args[0][0]) == "presto://localhost" + assert str(call_args[0][0]) == "presto://localhost" - assert call_args[1]["connect_args"] == { - "protocol": "https", - "username": "original_user", - "password": "original_user_password", - } + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + } @mock.patch("superset.models.core.create_engine") def test_impersonate_user_trino(self, mocked_create_engine): - uri = "trino://localhost" - principal_user = "logged_in_user" + principal_user = security_manager.find_user(username="gamma") - model = Database(database_name="test_database", sqlalchemy_uri=uri) - - model.impersonate_user = True - model.get_sqla_engine(user_name=principal_user) - call_args = mocked_create_engine.call_args - - assert str(call_args[0][0]) == "trino://localhost" + with override_user(principal_user): + model = Database( + database_name="test_database", sqlalchemy_uri="trino://localhost" + ) + model.impersonate_user = True + model.get_sqla_engine() + call_args = mocked_create_engine.call_args - assert call_args[1]["connect_args"] == { - "user": "logged_in_user", - } + assert str(call_args[0][0]) == "trino://localhost" + assert call_args[1]["connect_args"] == {"user": "gamma"} - uri = "trino://original_user:original_user_password@localhost" - model = 
Database(database_name="test_database", sqlalchemy_uri=uri) - model.impersonate_user = True - model.get_sqla_engine(user_name=principal_user) - call_args = mocked_create_engine.call_args + model = Database( + database_name="test_database", + sqlalchemy_uri="trino://original_user:original_user_password@localhost", + ) - assert str(call_args[0][0]) == "trino://original_user@localhost" + model.impersonate_user = True + model.get_sqla_engine() + call_args = mocked_create_engine.call_args - assert call_args[1]["connect_args"] == {"user": "logged_in_user"} + assert str(call_args[0][0]) == "trino://original_user@localhost" + assert call_args[1]["connect_args"] == {"user": "gamma"} @mock.patch("superset.models.core.create_engine") def test_impersonate_user_hive(self, mocked_create_engine): uri = "hive://localhost" - principal_user = "logged_in_user" + principal_user = security_manager.find_user(username="gamma") extra = """ { "metadata_params": {}, @@ -215,32 +220,34 @@ def test_impersonate_user_hive(self, mocked_create_engine): } """ - model = Database(database_name="test_database", sqlalchemy_uri=uri, extra=extra) + with override_user(principal_user): + model = Database( + database_name="test_database", sqlalchemy_uri=uri, extra=extra + ) + model.impersonate_user = True + model.get_sqla_engine() + call_args = mocked_create_engine.call_args - model.impersonate_user = True - model.get_sqla_engine(user_name=principal_user) - call_args = mocked_create_engine.call_args + assert str(call_args[0][0]) == "hive://localhost" - assert str(call_args[0][0]) == "hive://localhost" + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + "configuration": {"hive.server2.proxy.user": "gamma"}, + } - assert call_args[1]["connect_args"] == { - "protocol": "https", - "username": "original_user", - "password": "original_user_password", - "configuration": {"hive.server2.proxy.user": "logged_in_user"}, - } + model.impersonate_user = False + model.get_sqla_engine() + call_args = mocked_create_engine.call_args - model.impersonate_user = False - model.get_sqla_engine(user_name=principal_user) - call_args = mocked_create_engine.call_args + assert str(call_args[0][0]) == "hive://localhost" - assert str(call_args[0][0]) == "hive://localhost" - - assert call_args[1]["connect_args"] == { - "protocol": "https", - "username": "original_user", - "password": "original_user_password", - } + assert call_args[1]["connect_args"] == { + "protocol": "https", + "username": "original_user", + "password": "original_user_password", + } @pytest.mark.usefixtures("load_energy_table_with_slice") def test_select_star(self): @@ -345,19 +352,6 @@ def test_multi_statement(self): df = main_db.get_df("USE superset; SELECT ';';", None) self.assertEqual(df.iat[0, 0], ";") - @mock.patch("superset.models.core.Database.get_sqla_engine") - def test_username_param(self, mocked_get_sqla_engine): - main_db = get_example_database() - main_db.impersonate_user = True - test_username = "test_username_param" - - if main_db.backend == "mysql": - main_db.get_df("USE superset; SELECT 1", username=test_username) - mocked_get_sqla_engine.assert_called_with( - schema=None, - user_name="test_username_param", - ) - @mock.patch("superset.models.core.create_engine") def test_get_sqla_engine(self, mocked_create_engine): model = Database( diff --git a/tests/integration_tests/sql_validator_tests.py b/tests/integration_tests/sql_validator_tests.py index 57f31ba4b750d..ff4c74fa45fba 100644 --- 
a/tests/integration_tests/sql_validator_tests.py +++ b/tests/integration_tests/sql_validator_tests.py @@ -187,7 +187,7 @@ def tearDown(self): "message": "your query isn't how I like it", } - @patch("superset.sql_validators.presto_db.g") + @patch("superset.utils.core.g") def test_validator_success(self, flask_g): flask_g.user.username = "nobody" sql = "SELECT 1 FROM default.notarealtable" @@ -197,7 +197,7 @@ def test_validator_success(self, flask_g): self.assertEqual([], errors) - @patch("superset.sql_validators.presto_db.g") + @patch("superset.utils.core.g") def test_validator_db_error(self, flask_g): flask_g.user.username = "nobody" sql = "SELECT 1 FROM default.notarealtable" @@ -209,7 +209,7 @@ def test_validator_db_error(self, flask_g): with self.assertRaises(PrestoSQLValidationError): self.validator.validate(sql, schema, self.database) - @patch("superset.sql_validators.presto_db.g") + @patch("superset.utils.core.g") def test_validator_unexpected_error(self, flask_g): flask_g.user.username = "nobody" sql = "SELECT 1 FROM default.notarealtable" @@ -221,7 +221,7 @@ def test_validator_unexpected_error(self, flask_g): with self.assertRaises(Exception): self.validator.validate(sql, schema, self.database) - @patch("superset.sql_validators.presto_db.g") + @patch("superset.utils.core.g") def test_validator_query_error(self, flask_g): flask_g.user.username = "nobody" sql = "SELECT 1 FROM default.notarealtable" diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 49c6a771e5ad3..7ded842ef7afd 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -68,9 +68,9 @@ class TestSqlLab(SupersetTestCase): def run_some_queries(self): db.session.query(Query).delete() db.session.commit() - self.run_sql(QUERY_1, client_id="client_id_1", user_name="admin") - self.run_sql(QUERY_2, client_id="client_id_3", user_name="admin") - self.run_sql(QUERY_3, client_id="client_id_2", user_name="gamma_sqllab") + self.run_sql(QUERY_1, client_id="client_id_1", username="admin") + self.run_sql(QUERY_2, client_id="client_id_3", username="admin") + self.run_sql(QUERY_3, client_id="client_id_2", username="gamma_sqllab") self.logout() def tearDown(self): @@ -162,7 +162,7 @@ def test_sql_json_to_saved_query_info(self): db.session.commit() with freeze_time(datetime.now().isoformat(timespec="seconds")): - self.run_sql(sql_statement, "1") + self.run_sql(sql_statement, "1", username="admin") saved_query_ = ( db.session.query(SavedQuery) .filter( @@ -248,7 +248,7 @@ def test_sql_json_has_access(self): # Gamma user, with sqllab and db permission self.create_user_with_roles("Gagarin", ["ExampleDBAccess", "Gamma", "sql_lab"]) - data = self.run_sql(QUERY_1, "1", user_name="Gagarin") + data = self.run_sql(QUERY_1, "1", username="Gagarin") db.session.query(Query).delete() db.session.commit() self.assertLess(0, len(data["data"])) @@ -278,14 +278,14 @@ def test_sql_json_schema_access(self): ) data = self.run_sql( - f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", user_name="SchemaUser" + f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "3", username="SchemaUser" ) self.assertEqual(1, len(data["data"])) data = self.run_sql( f"SELECT * FROM {CTAS_SCHEMA_NAME}.test_table", "4", - user_name="SchemaUser", + username="SchemaUser", schema=CTAS_SCHEMA_NAME, ) self.assertEqual(1, len(data["data"])) @@ -295,7 +295,7 @@ def test_sql_json_schema_access(self): data = self.run_sql( "SELECT * FROM test_table", "5", - user_name="SchemaUser", + 
username="SchemaUser", schema=CTAS_SCHEMA_NAME, ) self.assertEqual(1, len(data["data"])) @@ -441,7 +441,7 @@ def test_alias_duplicate(self): self.run_sql( "SELECT name as col, gender as col FROM birth_names LIMIT 10", client_id="2e2df3", - user_name="admin", + username="admin", raise_on_error=True, ) @@ -747,7 +747,6 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query rendered_query=sql, return_results=True, store_results=False, - user_name="admin", session=mock_session, start_time=None, expand_data=False, @@ -758,7 +757,6 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query mock.call( "SET @value = 42", mock_query, - "admin", mock_session, mock_cursor, None, @@ -767,7 +765,6 @@ def test_execute_sql_statements(self, mock_execute_sql_statement, mock_get_query mock.call( "SELECT @value AS foo", mock_query, - "admin", mock_session, mock_cursor, None, @@ -804,7 +801,6 @@ def test_execute_sql_statements_no_results_backend( rendered_query=sql, return_results=True, store_results=False, - user_name="admin", session=mock_session, start_time=None, expand_data=False, @@ -858,7 +854,6 @@ def test_execute_sql_statements_ctas( rendered_query=sql, return_results=True, store_results=False, - user_name="admin", session=mock_session, start_time=None, expand_data=False, @@ -869,7 +864,6 @@ def test_execute_sql_statements_ctas( mock.call( "SET @value = 42", mock_query, - "admin", mock_session, mock_cursor, None, @@ -878,7 +872,6 @@ def test_execute_sql_statements_ctas( mock.call( "SELECT @value AS foo", mock_query, - "admin", mock_session, mock_cursor, None, @@ -895,7 +888,6 @@ def test_execute_sql_statements_ctas( rendered_query=sql, return_results=True, store_results=False, - user_name="admin", session=mock_session, start_time=None, expand_data=False, @@ -929,7 +921,6 @@ def test_execute_sql_statements_ctas( rendered_query=sql, return_results=True, store_results=False, - user_name="admin", session=mock_session, start_time=None, expand_data=False, diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index dcc2ded750fb6..9950fb9fedda5 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -20,6 +20,8 @@ from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session +from superset.utils.core import override_user + def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: """ @@ -46,7 +48,6 @@ def test_execute_sql_statement(mocker: MockerFixture, app: None) -> None: execute_sql_statement( sql_statement, query, - user_name=None, session=session, cursor=cursor, log_params={}, @@ -95,7 +96,6 @@ def test_execute_sql_statement_with_rls( execute_sql_statement( sql_statement, query, - user_name=None, session=session, cursor=cursor, log_params={}, @@ -153,16 +153,24 @@ def test_sql_lab_insert_rls( session.add(query) session.commit() - # first without RLS - superset_result_set = execute_sql_statement( - sql_statement=query.sql, - query=query, - user_name="admin", - session=session, - cursor=cursor, - log_params=None, - apply_ctas=False, + admin = User( + first_name="Alice", + last_name="Doe", + email="adoe@example.org", + username="admin", + roles=[Role(name="Admin")], ) + + # first without RLS + with override_user(admin): + superset_result_set = execute_sql_statement( + sql_statement=query.sql, + query=query, + session=session, + cursor=cursor, + log_params=None, + apply_ctas=False, + ) assert ( superset_result_set.to_pandas_df().to_markdown() == """ @@ -177,13 
+185,6 @@ def test_sql_lab_insert_rls( assert query.executed_sql == "SELECT c FROM t\nLIMIT 6" # now with RLS - admin = User( - first_name="Alice", - last_name="Doe", - email="adoe@example.org", - username="admin", - roles=[Role(name="Admin")], - ) rls = RowLevelSecurityFilter( filter_type=RowLevelSecurityFilterType.REGULAR, tables=[SqlaTable(database_id=1, schema=None, table_name="t")], @@ -196,15 +197,15 @@ def test_sql_lab_insert_rls( mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) - superset_result_set = execute_sql_statement( - sql_statement=query.sql, - query=query, - user_name="admin", - session=session, - cursor=cursor, - log_params=None, - apply_ctas=False, - ) + with override_user(admin): + superset_result_set = execute_sql_statement( + sql_statement=query.sql, + query=query, + session=session, + cursor=cursor, + log_params=None, + apply_ctas=False, + ) assert ( superset_result_set.to_pandas_df().to_markdown() == """
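
Taken together, the pattern these changes converge on is: call sites that used to thread user_name through function signatures now resolve a User once, attach it to flask.g with override_user(), and let everything downstream (engine creation, query mutation, logging) call get_username(). A minimal usage sketch, assuming an active Flask app context and an existing user record; run_as is just an illustrative wrapper, not an API added by this patch:

from superset import security_manager
from superset.utils.core import get_username, override_user


def run_as(username: str) -> None:
    # find_user() may return None; override_user() accepts that and simply
    # leaves no effective user for the duration of the block.
    user = security_manager.find_user(username=username)
    with override_user(user):
        # Same view as the Celery task, report commands, and CLI above:
        # get_username() reflects the overridden user while inside the block.
        assert get_username() == (user.username if user else None)
    # On normal exit the previous flask.g.user (or its absence) is restored.

This mirrors what get_sql_results, the alert/report commands, and the update_datasources_cache CLI now do before touching the database.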