From e1483277606ca296802b287e7db0b09f55df5c3d Mon Sep 17 00:00:00 2001 From: John Bodley <4567245+john-bodley@users.noreply.github.com> Date: Fri, 28 Jun 2024 12:33:56 -0700 Subject: [PATCH] chore(dao/command): Add transaction decorator to try to enforce "unit of work" (#24969) --- pyproject.toml | 1 + scripts/permissions_cleanup.py | 7 +- scripts/python_tests.sh | 1 + superset/cachekeys/api.py | 8 +- superset/cli/examples.py | 2 + superset/cli/main.py | 2 + superset/cli/test.py | 11 +- superset/cli/update.py | 3 + .../annotation_layer/annotation/create.py | 10 +- .../annotation_layer/annotation/delete.py | 11 +- .../annotation_layer/annotation/update.py | 12 +- superset/commands/annotation_layer/create.py | 10 +- superset/commands/annotation_layer/delete.py | 11 +- superset/commands/annotation_layer/update.py | 12 +- superset/commands/chart/create.py | 14 +-- superset/commands/chart/delete.py | 11 +- superset/commands/chart/importers/v1/utils.py | 2 +- superset/commands/chart/update.py | 27 ++-- superset/commands/css/delete.py | 11 +- superset/commands/dashboard/create.py | 13 +- superset/commands/dashboard/delete.py | 11 +- superset/commands/dashboard/importers/v0.py | 3 +- .../commands/dashboard/importers/v1/utils.py | 2 +- .../commands/dashboard/permalink/create.py | 54 ++++---- superset/commands/dashboard/update.py | 33 ++--- superset/commands/database/create.py | 14 +-- superset/commands/database/delete.py | 11 +- .../commands/database/ssh_tunnel/create.py | 11 +- .../commands/database/ssh_tunnel/delete.py | 10 +- .../commands/database/ssh_tunnel/update.py | 29 +++-- superset/commands/database/update.py | 32 ++--- superset/commands/database/uploaders/base.py | 10 +- superset/commands/dataset/columns/delete.py | 11 +- superset/commands/dataset/create.py | 19 +-- superset/commands/dataset/delete.py | 11 +- superset/commands/dataset/duplicate.py | 115 +++++++++--------- superset/commands/dataset/importers/v0.py | 5 +- .../commands/dataset/importers/v1/utils.py | 
2 +- superset/commands/dataset/metrics/delete.py | 11 +- superset/commands/dataset/refresh.py | 14 +-- superset/commands/dataset/update.py | 27 ++-- superset/commands/explore/permalink/create.py | 66 +++++----- superset/commands/importers/v1/__init__.py | 6 +- superset/commands/importers/v1/assets.py | 18 +-- superset/commands/importers/v1/examples.py | 5 +- superset/commands/key_value/create.py | 13 +- superset/commands/key_value/delete.py | 19 ++- superset/commands/key_value/delete_expired.py | 11 +- superset/commands/key_value/update.py | 11 +- superset/commands/key_value/upsert.py | 26 ++-- superset/commands/query/delete.py | 11 +- superset/commands/report/create.py | 10 +- superset/commands/report/delete.py | 11 +- superset/commands/report/execute.py | 7 +- superset/commands/report/log_prune.py | 14 +-- superset/commands/report/update.py | 13 +- superset/commands/security/create.py | 9 +- superset/commands/security/delete.py | 10 +- superset/commands/security/update.py | 12 +- superset/commands/sql_lab/execute.py | 22 +++- superset/commands/tag/create.py | 54 ++++---- superset/commands/tag/delete.py | 27 ++-- superset/commands/tag/update.py | 20 +-- superset/commands/temporary_cache/create.py | 11 +- superset/commands/temporary_cache/delete.py | 11 +- superset/commands/temporary_cache/update.py | 11 +- superset/connectors/sqla/models.py | 5 +- superset/daos/base.py | 48 ++------ superset/daos/chart.py | 2 - superset/daos/dashboard.py | 12 +- superset/daos/database.py | 6 +- superset/daos/dataset.py | 17 +-- superset/daos/exceptions.py | 24 ---- superset/daos/query.py | 2 - superset/daos/report.py | 40 ++---- superset/daos/tag.py | 66 ++-------- superset/daos/user.py | 1 - superset/dashboards/api.py | 10 +- superset/databases/api.py | 3 +- superset/db_engine_specs/gsheets.py | 2 +- superset/db_engine_specs/hive.py | 2 +- superset/db_engine_specs/impala.py | 2 +- superset/db_engine_specs/presto.py | 2 +- superset/db_engine_specs/trino.py | 1 + 
superset/examples/bart_lines.py | 1 - superset/examples/birth_names.py | 3 - superset/examples/country_map.py | 1 - superset/examples/css_templates.py | 2 - superset/examples/deck.py | 1 - superset/examples/energy.py | 3 - superset/examples/flights.py | 1 - superset/examples/helpers.py | 3 - superset/examples/long_lat.py | 1 - superset/examples/misc_dashboard.py | 1 - superset/examples/multiformat_time_series.py | 1 - superset/examples/paris.py | 1 - superset/examples/random_time_series.py | 2 - superset/examples/sf_population_polygons.py | 1 - .../examples/supported_charts_dashboard.py | 3 - superset/examples/tabbed_dashboard.py | 3 - superset/examples/world_bank.py | 6 +- superset/extensions/metastore_cache.py | 7 +- superset/extensions/pylint.py | 17 +++ superset/initialization/__init__.py | 2 + superset/key_value/shared_entries.py | 2 - superset/models/dashboard.py | 2 +- superset/queries/api.py | 4 +- superset/row_level_security/api.py | 6 +- superset/security/manager.py | 4 - superset/sql_lab.py | 2 + superset/sqllab/sql_json_executer.py | 3 + superset/tags/models.py | 1 + superset/tasks/celery_app.py | 2 +- superset/utils/database.py | 5 +- superset/utils/decorators.py | 63 ++++++++++ superset/utils/lock.py | 4 - superset/utils/log.py | 2 +- superset/views/base.py | 3 +- superset/views/core.py | 4 +- superset/views/dashboard/views.py | 2 +- superset/views/datasource/views.py | 2 +- superset/views/key_value.py | 2 +- superset/views/sql_lab/views.py | 2 +- tests/integration_tests/base_tests.py | 3 +- tests/integration_tests/charts/api_tests.py | 1 - .../charts/data/api_tests.py | 3 + tests/integration_tests/conftest.py | 4 - tests/integration_tests/core_tests.py | 2 +- tests/integration_tests/dashboard_tests.py | 6 +- .../dashboards/commands_tests.py | 1 - .../integration_tests/databases/api_tests.py | 3 - tests/integration_tests/datasets/api_tests.py | 17 +-- tests/integration_tests/datasource_tests.py | 2 - tests/integration_tests/embedded/api_tests.py | 1 + 
tests/integration_tests/embedded/dao_tests.py | 6 +- tests/integration_tests/embedded/test_view.py | 2 + .../fixtures/unicode_dashboard.py | 5 +- .../security/row_level_security_tests.py | 2 - tests/integration_tests/sqla_models_tests.py | 3 +- tests/integration_tests/sqllab_tests.py | 1 - .../integration_tests/superset_test_config.py | 1 + tests/integration_tests/tags/dao_tests.py | 3 +- .../commands/databases/create_test.py | 2 - .../commands/databases/update_test.py | 4 - tests/unit_tests/dao/tag_test.py | 7 -- tests/unit_tests/dao/user_test.py | 1 - tests/unit_tests/databases/api_test.py | 2 +- .../ssh_tunnel/commands/create_test.py | 2 +- .../databases/ssh_tunnel/dao_tests.py | 1 - tests/unit_tests/security/manager_test.py | 1 - tests/unit_tests/utils/lock_tests.py | 51 ++++---- 151 files changed, 682 insertions(+), 917 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7717611308361..efb211c0d34fe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -240,6 +240,7 @@ ignore_basepython_conflict = true commands = superset db upgrade superset init + superset load-test-users # use -s to be able to use break pointers. # no args or tests/* can be passed as an argument to run all tests pytest -s {posargs} diff --git a/scripts/permissions_cleanup.py b/scripts/permissions_cleanup.py index 0416f55806821..22e58f013fa3e 100644 --- a/scripts/permissions_cleanup.py +++ b/scripts/permissions_cleanup.py @@ -17,8 +17,10 @@ from collections import defaultdict from superset import security_manager +from superset.utils.decorators import transaction +@transaction() def cleanup_permissions() -> None: # 1. Clean up duplicates. 
pvms = security_manager.get_session.query( @@ -29,7 +31,6 @@ def cleanup_permissions() -> None: for pvm in pvms: pvms_dict[(pvm.permission, pvm.view_menu)].append(pvm) duplicates = [v for v in pvms_dict.values() if len(v) > 1] - len(duplicates) for pvm_list in duplicates: first_prm = pvm_list[0] @@ -38,7 +39,6 @@ def cleanup_permissions() -> None: roles = roles.union(pvm.role) security_manager.get_session.delete(pvm) first_prm.roles = list(roles) - security_manager.get_session.commit() pvms = security_manager.get_session.query( security_manager.permissionview_model @@ -52,7 +52,6 @@ def cleanup_permissions() -> None: for pvm in pvms: if not (pvm.view_menu and pvm.permission): security_manager.get_session.delete(pvm) - security_manager.get_session.commit() pvms = security_manager.get_session.query( security_manager.permissionview_model @@ -63,7 +62,6 @@ def cleanup_permissions() -> None: roles = security_manager.get_session.query(security_manager.role_model).all() for role in roles: role.permissions = [p for p in role.permissions if p] - security_manager.get_session.commit() # 4. 
Delete empty roles from permission view menus pvms = security_manager.get_session.query( @@ -71,7 +69,6 @@ def cleanup_permissions() -> None: ).all() for pvm in pvms: pvm.role = [r for r in pvm.role if r] - security_manager.get_session.commit() cleanup_permissions() diff --git a/scripts/python_tests.sh b/scripts/python_tests.sh index c3f27d17f78c4..e127d0c020621 100755 --- a/scripts/python_tests.sh +++ b/scripts/python_tests.sh @@ -29,6 +29,7 @@ echo "Superset config module: $SUPERSET_CONFIG" superset db upgrade superset init +superset load-test-users echo "Running tests" diff --git a/superset/cachekeys/api.py b/superset/cachekeys/api.py index 91cae29b8dc5a..093d81b1c3f7d 100644 --- a/superset/cachekeys/api.py +++ b/superset/cachekeys/api.py @@ -113,8 +113,10 @@ def invalidate(self) -> Response: delete_stmt = CacheKey.__table__.delete().where( # pylint: disable=no-member CacheKey.cache_key.in_(cache_keys) ) - db.session.execute(delete_stmt) - db.session.commit() + + with db.session.begin_nested(): + db.session.execute(delete_stmt) + stats_logger_manager.instance.gauge( "invalidated_cache", len(cache_keys) ) @@ -125,7 +127,5 @@ def invalidate(self) -> Response: ) except SQLAlchemyError as ex: # pragma: no cover logger.error(ex, exc_info=True) - db.session.rollback() return self.response_500(str(ex)) - db.session.commit() return self.response(201) diff --git a/superset/cli/examples.py b/superset/cli/examples.py index 3ce136ada7bc5..51b89f9641000 100755 --- a/superset/cli/examples.py +++ b/superset/cli/examples.py @@ -20,6 +20,7 @@ from flask.cli import with_appcontext import superset.utils.database as database_utils +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -89,6 +90,7 @@ def load_examples_run( @click.command() @with_appcontext +@transaction() @click.option("--load-test-data", "-t", is_flag=True, help="Load additional test data") @click.option("--load-big-data", "-b", is_flag=True, help="Load additional big data") 
@click.option( diff --git a/superset/cli/main.py b/superset/cli/main.py index aa7e3068f8b9d..ffe3278b11a0b 100755 --- a/superset/cli/main.py +++ b/superset/cli/main.py @@ -27,6 +27,7 @@ from superset import app, appbuilder, cli, security_manager from superset.cli.lib import normalize_token from superset.extensions import db +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -60,6 +61,7 @@ def make_shell_context() -> dict[str, Any]: @superset.command() @with_appcontext +@transaction() def init() -> None: """Inits the Superset application""" appbuilder.add_permissions(update_perms=True) diff --git a/superset/cli/test.py b/superset/cli/test.py index f175acec470cd..60ea532cbdba4 100755 --- a/superset/cli/test.py +++ b/superset/cli/test.py @@ -22,12 +22,14 @@ import superset.utils.database as database_utils from superset import app, security_manager +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @click.command() @with_appcontext +@transaction() def load_test_users() -> None: """ Loads admin, alpha, and gamma user for testing purposes @@ -35,15 +37,7 @@ def load_test_users() -> None: Syncs permissions for those users/roles """ print(Fore.GREEN + "Loading a set of users for unit tests") - load_test_users_run() - -def load_test_users_run() -> None: - """ - Loads admin, alpha, and gamma user for testing purposes - - Syncs permissions for those users/roles - """ if app.config["TESTING"]: sm = security_manager @@ -84,4 +78,3 @@ def load_test_users_run() -> None: sm.find_role(role), password="general", ) - sm.get_session.commit() diff --git a/superset/cli/update.py b/superset/cli/update.py index 9ff1f3bf58bfc..c162bb1e56eab 100755 --- a/superset/cli/update.py +++ b/superset/cli/update.py @@ -30,6 +30,7 @@ from flask_appbuilder.api.manager import resolver import superset.utils.database as database_utils +from superset.utils.decorators import transaction from superset.utils.encrypt import 
SecretsMigrator logger = logging.getLogger(__name__) @@ -37,6 +38,7 @@ @click.command() @with_appcontext +@transaction() @click.option("--database_name", "-d", help="Database name to change") @click.option("--uri", "-u", help="Database URI to change") @click.option( @@ -53,6 +55,7 @@ def set_database_uri(database_name: str, uri: str, skip_create: bool) -> None: @click.command() @with_appcontext +@transaction() def sync_tags() -> None: """Rebuilds special tags (owner, type, favorited by).""" # pylint: disable=no-member diff --git a/superset/commands/annotation_layer/annotation/create.py b/superset/commands/annotation_layer/annotation/create.py index feed6162cacbe..409efd33421ac 100644 --- a/superset/commands/annotation_layer/annotation/create.py +++ b/superset/commands/annotation_layer/annotation/create.py @@ -16,6 +16,7 @@ # under the License. import logging from datetime import datetime +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -30,7 +31,7 @@ from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError from superset.commands.base import BaseCommand from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO -from superset.daos.exceptions import DAOCreateFailedError +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -39,13 +40,10 @@ class CreateAnnotationCommand(BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=AnnotationCreateFailedError)) def run(self) -> Model: self.validate() - try: - return AnnotationDAO.create(attributes=self._properties) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise AnnotationCreateFailedError() from ex + return AnnotationDAO.create(attributes=self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git 
a/superset/commands/annotation_layer/annotation/delete.py b/superset/commands/annotation_layer/annotation/delete.py index 3f48ae2ceb120..125265449edeb 100644 --- a/superset/commands/annotation_layer/annotation/delete.py +++ b/superset/commands/annotation_layer/annotation/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Optional from superset.commands.annotation_layer.annotation.exceptions import ( @@ -23,8 +24,8 @@ ) from superset.commands.base import BaseCommand from superset.daos.annotation_layer import AnnotationDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.models.annotations import Annotation +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -34,15 +35,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[Annotation]] = None + @transaction(on_error=partial(on_error, reraise=AnnotationDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - AnnotationDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise AnnotationDeleteFailedError() from ex + AnnotationDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/annotation_layer/annotation/update.py b/superset/commands/annotation_layer/annotation/update.py index 9ba07fdcd68d2..129b09fcb36fd 100644 --- a/superset/commands/annotation_layer/annotation/update.py +++ b/superset/commands/annotation_layer/annotation/update.py @@ -16,6 +16,7 @@ # under the License. 
import logging from datetime import datetime +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -31,8 +32,8 @@ from superset.commands.annotation_layer.exceptions import AnnotationLayerNotFoundError from superset.commands.base import BaseCommand from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO -from superset.daos.exceptions import DAOUpdateFailedError from superset.models.annotations import Annotation +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -43,16 +44,11 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[Annotation] = None + @transaction(on_error=partial(on_error, reraise=AnnotationUpdateFailedError)) def run(self) -> Model: self.validate() assert self._model - - try: - annotation = AnnotationDAO.update(self._model, self._properties) - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) - raise AnnotationUpdateFailedError() from ex - return annotation + return AnnotationDAO.update(self._model, self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/annotation_layer/create.py b/superset/commands/annotation_layer/create.py index 6b87ad570363a..0f06e2b2744d7 100644 --- a/superset/commands/annotation_layer/create.py +++ b/superset/commands/annotation_layer/create.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any from flask_appbuilder.models.sqla import Model @@ -27,7 +28,7 @@ ) from superset.commands.base import BaseCommand from superset.daos.annotation_layer import AnnotationLayerDAO -from superset.daos.exceptions import DAOCreateFailedError +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -36,13 +37,10 @@ class CreateAnnotationLayerCommand(BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=AnnotationLayerCreateFailedError)) def run(self) -> Model: self.validate() - try: - return AnnotationLayerDAO.create(attributes=self._properties) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise AnnotationLayerCreateFailedError() from ex + return AnnotationLayerDAO.create(attributes=self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/annotation_layer/delete.py b/superset/commands/annotation_layer/delete.py index a75ee42b772e0..b97b7ac0933f2 100644 --- a/superset/commands/annotation_layer/delete.py +++ b/superset/commands/annotation_layer/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from superset.commands.annotation_layer.exceptions import ( @@ -24,8 +25,8 @@ ) from superset.commands.base import BaseCommand from superset.daos.annotation_layer import AnnotationLayerDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.models.annotations import AnnotationLayer +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -35,15 +36,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[AnnotationLayer]] = None + @transaction(on_error=partial(on_error, reraise=AnnotationLayerDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - AnnotationLayerDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise AnnotationLayerDeleteFailedError() from ex + AnnotationLayerDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/annotation_layer/update.py b/superset/commands/annotation_layer/update.py index d15440882b155..c4e18bdd09eeb 100644 --- a/superset/commands/annotation_layer/update.py +++ b/superset/commands/annotation_layer/update.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -28,8 +29,8 @@ ) from superset.commands.base import BaseCommand from superset.daos.annotation_layer import AnnotationLayerDAO -from superset.daos.exceptions import DAOUpdateFailedError from superset.models.annotations import AnnotationLayer +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -40,16 +41,11 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[AnnotationLayer] = None + @transaction(on_error=partial(on_error, reraise=AnnotationLayerUpdateFailedError)) def run(self) -> Model: self.validate() assert self._model - - try: - annotation_layer = AnnotationLayerDAO.update(self._model, self._properties) - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) - raise AnnotationLayerUpdateFailedError() from ex - return annotation_layer + return AnnotationLayerDAO.update(self._model, self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/chart/create.py b/superset/commands/chart/create.py index 2b251029c3f38..84b3aa29411ef 100644 --- a/superset/commands/chart/create.py +++ b/superset/commands/chart/create.py @@ -16,6 +16,7 @@ # under the License. 
import logging from datetime import datetime +from functools import partial from typing import Any, Optional from flask import g @@ -33,7 +34,7 @@ from superset.commands.utils import get_datasource_by_id from superset.daos.chart import ChartDAO from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAOCreateFailedError +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -42,15 +43,12 @@ class CreateChartCommand(CreateMixin, BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=ChartCreateFailedError)) def run(self) -> Model: self.validate() - try: - self._properties["last_saved_at"] = datetime.now() - self._properties["last_saved_by"] = g.user - return ChartDAO.create(attributes=self._properties) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise ChartCreateFailedError() from ex + self._properties["last_saved_at"] = datetime.now() + self._properties["last_saved_by"] = g.user + return ChartDAO.create(attributes=self._properties) def validate(self) -> None: exceptions = [] diff --git a/superset/commands/chart/delete.py b/superset/commands/chart/delete.py index 8694ae1feb32d..00e6d201bcc95 100644 --- a/superset/commands/chart/delete.py +++ b/superset/commands/chart/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from flask_babel import lazy_gettext as _ @@ -28,10 +29,10 @@ ChartNotFoundError, ) from superset.daos.chart import ChartDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.report import ReportScheduleDAO from superset.exceptions import SupersetSecurityException from superset.models.slice import Slice +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -41,15 +42,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[Slice]] = None + @transaction(on_error=partial(on_error, reraise=ChartDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - ChartDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise ChartDeleteFailedError() from ex + ChartDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/chart/importers/v1/utils.py b/superset/commands/chart/importers/v1/utils.py index 39ca49a5d5ffc..35a7f6e2700f3 100644 --- a/superset/commands/chart/importers/v1/utils.py +++ b/superset/commands/chart/importers/v1/utils.py @@ -77,7 +77,7 @@ def import_chart( if chart.id is None: db.session.flush() - if user := get_user(): + if (user := get_user()) and user not in chart.owners: chart.owners.append(user) return chart diff --git a/superset/commands/chart/update.py b/superset/commands/chart/update.py index 74b1c30aa83c8..d6b212d5ce861 100644 --- a/superset/commands/chart/update.py +++ b/superset/commands/chart/update.py @@ -16,6 +16,7 @@ # under the License. 
import logging from datetime import datetime +from functools import partial from typing import Any, Optional from flask import g @@ -35,10 +36,10 @@ from superset.commands.utils import get_datasource_by_id, update_tags, validate_tags from superset.daos.chart import ChartDAO from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError from superset.exceptions import SupersetSecurityException from superset.models.slice import Slice from superset.tags.models import ObjectType +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -55,24 +56,20 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[Slice] = None + @transaction(on_error=partial(on_error, reraise=ChartUpdateFailedError)) def run(self) -> Model: self.validate() assert self._model - try: - # Update tags - tags = self._properties.pop("tags", None) - if tags is not None: - update_tags(ObjectType.chart, self._model.id, self._model.tags, tags) - - if self._properties.get("query_context_generation") is None: - self._properties["last_saved_at"] = datetime.now() - self._properties["last_saved_by"] = g.user - chart = ChartDAO.update(self._model, self._properties) - except (DAOUpdateFailedError, DAODeleteFailedError) as ex: - logger.exception(ex.exception) - raise ChartUpdateFailedError() from ex - return chart + # Update tags + if (tags := self._properties.pop("tags", None)) is not None: + update_tags(ObjectType.chart, self._model.id, self._model.tags, tags) + + if self._properties.get("query_context_generation") is None: + self._properties["last_saved_at"] = datetime.now() + self._properties["last_saved_by"] = g.user + + return ChartDAO.update(self._model, self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/css/delete.py b/superset/commands/css/delete.py index 
b8362f6b464dd..c6559eb06665b 100644 --- a/superset/commands/css/delete.py +++ b/superset/commands/css/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Optional from superset.commands.base import BaseCommand @@ -23,8 +24,8 @@ CssTemplateNotFoundError, ) from superset.daos.css import CssTemplateDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.models.core import CssTemplate +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -34,15 +35,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[CssTemplate]] = None + @transaction(on_error=partial(on_error, reraise=CssTemplateDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - CssTemplateDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise CssTemplateDeleteFailedError() from ex + CssTemplateDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dashboard/create.py b/superset/commands/dashboard/create.py index 1745391238d75..469d3d81af25c 100644 --- a/superset/commands/dashboard/create.py +++ b/superset/commands/dashboard/create.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -28,23 +29,19 @@ ) from superset.commands.utils import populate_roles from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAOCreateFailedError +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) class CreateDashboardCommand(CreateMixin, BaseCommand): - def __init__(self, data: dict[str, Any]): + def __init__(self, data: dict[str, Any]) -> None: self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=DashboardCreateFailedError)) def run(self) -> Model: self.validate() - try: - dashboard = DashboardDAO.create(attributes=self._properties, commit=True) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise DashboardCreateFailedError() from ex - return dashboard + return DashboardDAO.create(attributes=self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/dashboard/delete.py b/superset/commands/dashboard/delete.py index 569d05dac74de..0135c4303f292 100644 --- a/superset/commands/dashboard/delete.py +++ b/superset/commands/dashboard/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from flask_babel import lazy_gettext as _ @@ -28,10 +29,10 @@ DashboardNotFoundError, ) from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.report import ReportScheduleDAO from superset.exceptions import SupersetSecurityException from superset.models.dashboard import Dashboard +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -41,15 +42,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[Dashboard]] = None + @transaction(on_error=partial(on_error, reraise=DashboardDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - DashboardDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise DashboardDeleteFailedError() from ex + DashboardDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dashboard/importers/v0.py b/superset/commands/dashboard/importers/v0.py index a9ee3e484e1c3..99090e7d417fa 100644 --- a/superset/commands/dashboard/importers/v0.py +++ b/superset/commands/dashboard/importers/v0.py @@ -36,6 +36,7 @@ convert_filter_scopes, copy_filter_scopes, ) +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -311,7 +312,6 @@ def import_dashboards( for dashboard in data["dashboards"]: import_dashboard(dashboard, dataset_id_mapping, import_time=import_time) - db.session.commit() class ImportDashboardsCommand(BaseCommand): @@ -329,6 +329,7 @@ def __init__( self.contents = contents self.database_id = database_id + @transaction() def run(self) -> None: self.validate() diff --git a/superset/commands/dashboard/importers/v1/utils.py b/superset/commands/dashboard/importers/v1/utils.py index f10afd12bc9ee..5e949093b8a80 100644 --- 
a/superset/commands/dashboard/importers/v1/utils.py +++ b/superset/commands/dashboard/importers/v1/utils.py @@ -188,7 +188,7 @@ def import_dashboard( if dashboard.id is None: db.session.flush() - if user := get_user(): + if (user := get_user()) and user not in dashboard.owners: dashboard.owners.append(user) return dashboard diff --git a/superset/commands/dashboard/permalink/create.py b/superset/commands/dashboard/permalink/create.py index 76b7b8e83453c..7d08f78e9a9be 100644 --- a/superset/commands/dashboard/permalink/create.py +++ b/superset/commands/dashboard/permalink/create.py @@ -15,18 +15,22 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from sqlalchemy.exc import SQLAlchemyError -from superset import db from superset.commands.dashboard.permalink.base import BaseDashboardPermalinkCommand from superset.commands.key_value.upsert import UpsertKeyValueCommand from superset.daos.dashboard import DashboardDAO from superset.dashboards.permalink.exceptions import DashboardPermalinkCreateFailedError from superset.dashboards.permalink.types import DashboardPermalinkState -from superset.key_value.exceptions import KeyValueCodecEncodeException +from superset.key_value.exceptions import ( + KeyValueCodecEncodeException, + KeyValueUpsertFailedError, +) from superset.key_value.utils import encode_permalink_key, get_deterministic_uuid from superset.utils.core import get_user_id +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -47,29 +51,33 @@ def __init__( self.dashboard_id = dashboard_id self.state = state + @transaction( + on_error=partial( + on_error, + catches=( + KeyValueCodecEncodeException, + KeyValueUpsertFailedError, + SQLAlchemyError, + ), + reraise=DashboardPermalinkCreateFailedError, + ), + ) def run(self) -> str: self.validate() - try: - dashboard = DashboardDAO.get_by_id_or_slug(self.dashboard_id) - value = { - "dashboardId": 
str(dashboard.uuid), - "state": self.state, - } - user_id = get_user_id() - key = UpsertKeyValueCommand( - resource=self.resource, - key=get_deterministic_uuid(self.salt, (user_id, value)), - value=value, - codec=self.codec, - ).run() - assert key.id # for type checks - db.session.commit() - return encode_permalink_key(key=key.id, salt=self.salt) - except KeyValueCodecEncodeException as ex: - raise DashboardPermalinkCreateFailedError(str(ex)) from ex - except SQLAlchemyError as ex: - logger.exception("Error running create command") - raise DashboardPermalinkCreateFailedError() from ex + dashboard = DashboardDAO.get_by_id_or_slug(self.dashboard_id) + value = { + "dashboardId": str(dashboard.uuid), + "state": self.state, + } + user_id = get_user_id() + key = UpsertKeyValueCommand( + resource=self.resource, + key=get_deterministic_uuid(self.salt, (user_id, value)), + value=value, + codec=self.codec, + ).run() + assert key.id # for type checks + return encode_permalink_key(key=key.id, salt=self.salt) def validate(self) -> None: pass diff --git a/superset/commands/dashboard/update.py b/superset/commands/dashboard/update.py index 890422602dd68..2effd7bd2ece1 100644 --- a/superset/commands/dashboard/update.py +++ b/superset/commands/dashboard/update.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -31,12 +32,11 @@ ) from superset.commands.utils import populate_roles, update_tags, validate_tags from superset.daos.dashboard import DashboardDAO -from superset.daos.exceptions import DAODeleteFailedError, DAOUpdateFailedError from superset.exceptions import SupersetSecurityException -from superset.extensions import db from superset.models.dashboard import Dashboard from superset.tags.models import ObjectType from superset.utils import json +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -47,29 +47,22 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[Dashboard] = None + @transaction(on_error=partial(on_error, reraise=DashboardUpdateFailedError)) def run(self) -> Model: self.validate() assert self._model - try: - # Update tags - tags = self._properties.pop("tags", None) - if tags is not None: - update_tags( - ObjectType.dashboard, self._model.id, self._model.tags, tags - ) + # Update tags + if (tags := self._properties.pop("tags", None)) is not None: + update_tags(ObjectType.dashboard, self._model.id, self._model.tags, tags) + + dashboard = DashboardDAO.update(self._model, self._properties) + if self._properties.get("json_metadata"): + DashboardDAO.set_dash_metadata( + dashboard, + data=json.loads(self._properties.get("json_metadata", "{}")), + ) - dashboard = DashboardDAO.update(self._model, self._properties, commit=False) - if self._properties.get("json_metadata"): - dashboard = DashboardDAO.set_dash_metadata( - dashboard, - data=json.loads(self._properties.get("json_metadata", "{}")), - commit=False, - ) - db.session.commit() - except (DAOUpdateFailedError, DAODeleteFailedError) as ex: - logger.exception(ex.exception) - raise DashboardUpdateFailedError() from ex return dashboard def validate(self) -> None: diff --git 
a/superset/commands/database/create.py b/superset/commands/database/create.py index e66e1110c8dff..76dd6087be58a 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Any, Optional from flask import current_app @@ -39,11 +40,11 @@ ) from superset.commands.database.test_connection import TestConnectionDatabaseCommand from superset.daos.database import DatabaseDAO -from superset.daos.exceptions import DAOCreateFailedError from superset.databases.ssh_tunnel.models import SSHTunnel from superset.exceptions import SupersetErrorsException -from superset.extensions import db, event_logger, security_manager +from superset.extensions import event_logger, security_manager from superset.models.core import Database +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) stats_logger = current_app.config["STATS_LOGGER"] @@ -53,6 +54,7 @@ class CreateDatabaseCommand(BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=DatabaseCreateFailedError)) def run(self) -> Model: self.validate() @@ -96,8 +98,6 @@ def run(self) -> Model: database, ssh_tunnel_properties ).run() - db.session.commit() - # add catalog/schema permissions if database.db_engine_spec.supports_catalog: catalogs = database.get_all_catalog_names( @@ -121,14 +121,12 @@ def run(self) -> Model: except Exception: # pylint: disable=broad-except logger.warning("Error processing catalog '%s'", catalog) continue - except ( SSHTunnelInvalidError, SSHTunnelCreateFailedError, SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError, ) as ex: - db.session.rollback() event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", engine=self._properties.get("sqlalchemy_uri", 
"").split(":")[0], @@ -136,11 +134,9 @@ def run(self) -> Model: # So we can show the original message raise except ( - DAOCreateFailedError, DatabaseInvalidError, Exception, ) as ex: - db.session.rollback() event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", engine=database.db_engine_spec.__name__, @@ -198,6 +194,6 @@ def validate(self) -> None: raise exception def _create_database(self) -> Database: - database = DatabaseDAO.create(attributes=self._properties, commit=False) + database = DatabaseDAO.create(attributes=self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) return database diff --git a/superset/commands/database/delete.py b/superset/commands/database/delete.py index ce0775506c3a9..bf499dac4ff49 100644 --- a/superset/commands/database/delete.py +++ b/superset/commands/database/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Optional from flask_babel import lazy_gettext as _ @@ -27,9 +28,9 @@ DatabaseNotFoundError, ) from superset.daos.database import DatabaseDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.report import ReportScheduleDAO from superset.models.core import Database +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -39,15 +40,11 @@ def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[Database] = None + @transaction(on_error=partial(on_error, reraise=DatabaseDeleteFailedError)) def run(self) -> None: self.validate() assert self._model - - try: - DatabaseDAO.delete([self._model]) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise DatabaseDeleteFailedError() from ex + DatabaseDAO.delete([self._model]) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/database/ssh_tunnel/create.py 
b/superset/commands/database/ssh_tunnel/create.py index 40083b4b648aa..89e607ba67aed 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -28,10 +29,10 @@ SSHTunnelRequiredFieldValidationError, ) from superset.daos.database import SSHTunnelDAO -from superset.daos.exceptions import DAOCreateFailedError from superset.databases.utils import make_url_safe from superset.extensions import event_logger from superset.models.core import Database +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -44,6 +45,7 @@ def __init__(self, database: Database, data: dict[str, Any]): self._properties["database"] = database self._database = database + @transaction(on_error=partial(on_error, reraise=SSHTunnelCreateFailedError)) def run(self) -> Model: """ Create an SSH tunnel. @@ -53,11 +55,8 @@ def run(self) -> Model: :raises SSHTunnelInvalidError: If the configuration are invalid """ - try: - self.validate() - return SSHTunnelDAO.create(attributes=self._properties, commit=False) - except DAOCreateFailedError as ex: - raise SSHTunnelCreateFailedError() from ex + self.validate() + return SSHTunnelDAO.create(attributes=self._properties) def validate(self) -> None: # TODO(hughhh): check to make sure the server port is not localhost diff --git a/superset/commands/database/ssh_tunnel/delete.py b/superset/commands/database/ssh_tunnel/delete.py index b8919e6d7bae6..8c742307aa8d9 100644 --- a/superset/commands/database/ssh_tunnel/delete.py +++ b/superset/commands/database/ssh_tunnel/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from superset import is_feature_enabled @@ -25,8 +26,8 @@ SSHTunnelNotFoundError, ) from superset.daos.database import SSHTunnelDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.databases.ssh_tunnel.models import SSHTunnel +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -36,16 +37,13 @@ def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[SSHTunnel] = None + @transaction(on_error=partial(on_error, reraise=SSHTunnelDeleteFailedError)) def run(self) -> None: if not is_feature_enabled("SSH_TUNNELING"): raise SSHTunnelingNotEnabledError() self.validate() assert self._model - - try: - SSHTunnelDAO.delete([self._model]) - except DAODeleteFailedError as ex: - raise SSHTunnelDeleteFailedError() from ex + SSHTunnelDAO.delete([self._model]) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py index d0dd14a5b2372..b2fa416bd597e 100644 --- a/superset/commands/database/ssh_tunnel/update.py +++ b/superset/commands/database/ssh_tunnel/update.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -28,9 +29,9 @@ SSHTunnelUpdateFailedError, ) from superset.daos.database import SSHTunnelDAO -from superset.daos.exceptions import DAOUpdateFailedError from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -41,25 +42,23 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._model: Optional[SSHTunnel] = None + @transaction(on_error=partial(on_error, reraise=SSHTunnelUpdateFailedError)) def run(self) -> Optional[Model]: self.validate() - try: - if self._model is None: - return None - # unset password if private key is provided - if self._properties.get("private_key"): - self._properties["password"] = None + if self._model is None: + return None - # unset private key and password if password is provided - if self._properties.get("password"): - self._properties["private_key"] = None - self._properties["private_key_password"] = None + # unset password if private key is provided + if self._properties.get("private_key"): + self._properties["password"] = None - tunnel = SSHTunnelDAO.update(self._model, self._properties) - return tunnel - except DAOUpdateFailedError as ex: - raise SSHTunnelUpdateFailedError() from ex + # unset private key and password if password is provided + if self._properties.get("password"): + self._properties["private_key"] = None + self._properties["private_key_password"] = None + + return SSHTunnelDAO.update(self._model, self._properties) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 61b0d51ed826d..28f895b2f632b 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -18,6 +18,7 @@ 
from __future__ import annotations import logging +from functools import partial from typing import Any from flask_appbuilder.models.sqla import Model @@ -34,16 +35,14 @@ from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand from superset.commands.database.ssh_tunnel.exceptions import ( - SSHTunnelError, SSHTunnelingNotEnabledError, ) from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.daos.database import DatabaseDAO from superset.daos.dataset import DatasetDAO -from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError from superset.databases.ssh_tunnel.models import SSHTunnel -from superset.extensions import db from superset.models.core import Database +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -56,6 +55,7 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._model_id = model_id self._model: Database | None = None + @transaction(on_error=partial(on_error, reraise=DatabaseUpdateFailedError)) def run(self) -> Model: self._model = DatabaseDAO.find_by_id(self._model_id) @@ -76,21 +76,10 @@ def run(self) -> Model: # since they're name based original_database_name = self._model.database_name - try: - database = DatabaseDAO.update( - self._model, - self._properties, - commit=False, - ) - database.set_sqlalchemy_uri(database.sqlalchemy_uri) - ssh_tunnel = self._handle_ssh_tunnel(database) - self._refresh_catalogs(database, original_database_name, ssh_tunnel) - except SSHTunnelError: # pylint: disable=try-except-raise - # allow exception to bubble for debugbing information - raise - except (DAOUpdateFailedError, DAOCreateFailedError) as ex: - raise DatabaseUpdateFailedError() from ex - + database = DatabaseDAO.update(self._model, self._properties) + database.set_sqlalchemy_uri(database.sqlalchemy_uri) + ssh_tunnel = 
self._handle_ssh_tunnel(database) + self._refresh_catalogs(database, original_database_name, ssh_tunnel) return database def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: @@ -101,7 +90,6 @@ def _handle_ssh_tunnel(self, database: Database) -> SSHTunnel | None: return None if not is_feature_enabled("SSH_TUNNELING"): - db.session.rollback() raise SSHTunnelingNotEnabledError() current_ssh_tunnel = DatabaseDAO.get_ssh_tunnel(database.id) @@ -131,13 +119,13 @@ def _get_catalog_names( This method captures a generic exception, since errors could potentially come from any of the 50+ database drivers we support. """ + try: return database.get_all_catalog_names( force=True, ssh_tunnel=ssh_tunnel, ) except Exception as ex: - db.session.rollback() raise DatabaseConnectionFailedError() from ex def _get_schema_names( @@ -152,6 +140,7 @@ def _get_schema_names( This method captures a generic exception, since errors could potentially come from any of the 50+ database drivers we support. """ + try: return database.get_all_schema_names( force=True, @@ -159,7 +148,6 @@ def _get_schema_names( ssh_tunnel=ssh_tunnel, ) except Exception as ex: - db.session.rollback() raise DatabaseConnectionFailedError() from ex def _refresh_catalogs( @@ -225,8 +213,6 @@ def _refresh_catalogs( schemas, ) - db.session.commit() - def _refresh_schemas( self, database: Database, diff --git a/superset/commands/database/uploaders/base.py b/superset/commands/database/uploaders/base.py index b113e9ebf45d9..0e939ef4324da 100644 --- a/superset/commands/database/uploaders/base.py +++ b/superset/commands/database/uploaders/base.py @@ -16,11 +16,11 @@ # under the License. 
import logging from abc import abstractmethod +from functools import partial from typing import Any, Optional, TypedDict import pandas as pd from flask_babel import lazy_gettext as _ -from sqlalchemy.exc import SQLAlchemyError from werkzeug.datastructures import FileStorage from superset import db @@ -37,6 +37,7 @@ from superset.models.core import Database from superset.sql_parse import Table from superset.utils.core import get_user +from superset.utils.decorators import on_error, transaction from superset.views.database.validators import schema_allows_file_upload logger = logging.getLogger(__name__) @@ -144,6 +145,7 @@ def __init__( # pylint: disable=too-many-arguments self._file = file self._reader = reader + @transaction(on_error=partial(on_error, reraise=DatabaseUploadSaveMetadataFailed)) def run(self) -> None: self.validate() if not self._model: @@ -172,12 +174,6 @@ def run(self) -> None: sqla_table.fetch_metadata() - try: - db.session.commit() - except SQLAlchemyError as ex: - db.session.rollback() - raise DatabaseUploadSaveMetadataFailed() from ex - def validate(self) -> None: self._model = DatabaseDAO.find_by_id(self._model_id) if not self._model: diff --git a/superset/commands/dataset/columns/delete.py b/superset/commands/dataset/columns/delete.py index 4739c2520f880..821528de74d4f 100644 --- a/superset/commands/dataset/columns/delete.py +++ b/superset/commands/dataset/columns/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from superset import security_manager @@ -26,8 +27,8 @@ ) from superset.connectors.sqla.models import TableColumn from superset.daos.dataset import DatasetColumnDAO, DatasetDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.exceptions import SupersetSecurityException +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -38,15 +39,11 @@ def __init__(self, dataset_id: int, model_id: int): self._model_id = model_id self._model: Optional[TableColumn] = None + @transaction(on_error=partial(on_error, reraise=DatasetColumnDeleteFailedError)) def run(self) -> None: self.validate() assert self._model - - try: - DatasetColumnDAO.delete([self._model]) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise DatasetColumnDeleteFailedError() from ex + DatasetColumnDAO.delete([self._model]) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dataset/create.py b/superset/commands/dataset/create.py index b72c3ff46ebb8..a2d81e548bfb0 100644 --- a/superset/commands/dataset/create.py +++ b/superset/commands/dataset/create.py @@ -15,11 +15,11 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError -from sqlalchemy.exc import SQLAlchemyError from superset.commands.base import BaseCommand, CreateMixin from superset.commands.dataset.exceptions import ( @@ -31,10 +31,10 @@ TableNotFoundValidationError, ) from superset.daos.dataset import DatasetDAO -from superset.daos.exceptions import DAOCreateFailedError from superset.exceptions import SupersetSecurityException -from superset.extensions import db, security_manager +from superset.extensions import security_manager from superset.sql_parse import Table +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -43,19 +43,12 @@ class CreateDatasetCommand(CreateMixin, BaseCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=DatasetCreateFailedError)) def run(self) -> Model: self.validate() - try: - # Creates SqlaTable (Dataset) - dataset = DatasetDAO.create(attributes=self._properties, commit=False) - # Updates columns and metrics from the dataset - dataset.fetch_metadata(commit=False) - db.session.commit() - except (SQLAlchemyError, DAOCreateFailedError) as ex: - logger.warning(ex, exc_info=True) - db.session.rollback() - raise DatasetCreateFailedError() from ex + dataset = DatasetDAO.create(attributes=self._properties) + dataset.fetch_metadata() return dataset def validate(self) -> None: diff --git a/superset/commands/dataset/delete.py b/superset/commands/dataset/delete.py index 4b7e61ab4c113..27753062aa767 100644 --- a/superset/commands/dataset/delete.py +++ b/superset/commands/dataset/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from superset import security_manager @@ -26,8 +27,8 @@ ) from superset.connectors.sqla.models import SqlaTable from superset.daos.dataset import DatasetDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.exceptions import SupersetSecurityException +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -37,15 +38,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[SqlaTable]] = None + @transaction(on_error=partial(on_error, reraise=DatasetDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - DatasetDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise DatasetDeleteFailedError() from ex + DatasetDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dataset/duplicate.py b/superset/commands/dataset/duplicate.py index efe4935e60af7..8e82a7662f652 100644 --- a/superset/commands/dataset/duplicate.py +++ b/superset/commands/dataset/duplicate.py @@ -15,12 +15,12 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any from flask_appbuilder.models.sqla import Model from flask_babel import gettext as __ from marshmallow import ValidationError -from sqlalchemy.exc import SQLAlchemyError from superset.commands.base import BaseCommand, CreateMixin from superset.commands.dataset.exceptions import ( @@ -32,12 +32,12 @@ from superset.commands.exceptions import DatasourceTypeInvalidError from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.daos.dataset import DatasetDAO -from superset.daos.exceptions import DAOCreateFailedError from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException from superset.extensions import db from superset.models.core import Database from superset.sql_parse import ParsedQuery, Table +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -47,66 +47,61 @@ def __init__(self, data: dict[str, Any]) -> None: self._base_model: SqlaTable = SqlaTable() self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=DatasetDuplicateFailedError)) def run(self) -> Model: self.validate() - try: - database_id = self._base_model.database_id - table_name = self._properties["table_name"] - owners = self._properties["owners"] - database = db.session.query(Database).get(database_id) - if not database: - raise SupersetErrorException( - SupersetError( - message=__("The database was not found."), - error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR, - level=ErrorLevel.ERROR, - ), - status=404, - ) - table = SqlaTable(table_name=table_name, owners=owners) - table.database = database - table.schema = self._base_model.schema - table.template_params = self._base_model.template_params - table.normalize_columns = self._base_model.normalize_columns - table.always_filter_main_dttm = self._base_model.always_filter_main_dttm - table.is_sqllab_view = 
True - table.sql = ParsedQuery( - self._base_model.sql, - engine=database.db_engine_spec.engine, - ).stripped() - db.session.add(table) - cols = [] - for config_ in self._base_model.columns: - column_name = config_.column_name - col = TableColumn( - column_name=column_name, - verbose_name=config_.verbose_name, - expression=config_.expression, - filterable=True, - groupby=True, - is_dttm=config_.is_dttm, - type=config_.type, - description=config_.description, - ) - cols.append(col) - table.columns = cols - mets = [] - for config_ in self._base_model.metrics: - metric_name = config_.metric_name - met = SqlMetric( - metric_name=metric_name, - verbose_name=config_.verbose_name, - expression=config_.expression, - metric_type=config_.metric_type, - description=config_.description, - ) - mets.append(met) - table.metrics = mets - db.session.commit() - except (SQLAlchemyError, DAOCreateFailedError) as ex: - logger.warning(ex, exc_info=True) - db.session.rollback() - raise DatasetDuplicateFailedError() from ex + database_id = self._base_model.database_id + table_name = self._properties["table_name"] + owners = self._properties["owners"] + database = db.session.query(Database).get(database_id) + if not database: + raise SupersetErrorException( + SupersetError( + message=__("The database was not found."), + error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR, + level=ErrorLevel.ERROR, + ), + status=404, + ) + table = SqlaTable(table_name=table_name, owners=owners) + table.database = database + table.schema = self._base_model.schema + table.template_params = self._base_model.template_params + table.normalize_columns = self._base_model.normalize_columns + table.always_filter_main_dttm = self._base_model.always_filter_main_dttm + table.is_sqllab_view = True + table.sql = ParsedQuery( + self._base_model.sql, + engine=database.db_engine_spec.engine, + ).stripped() + db.session.add(table) + cols = [] + for config_ in self._base_model.columns: + column_name = config_.column_name + 
col = TableColumn( + column_name=column_name, + verbose_name=config_.verbose_name, + expression=config_.expression, + filterable=True, + groupby=True, + is_dttm=config_.is_dttm, + type=config_.type, + description=config_.description, + ) + cols.append(col) + table.columns = cols + mets = [] + for config_ in self._base_model.metrics: + metric_name = config_.metric_name + met = SqlMetric( + metric_name=metric_name, + verbose_name=config_.verbose_name, + expression=config_.expression, + metric_type=config_.metric_type, + description=config_.description, + ) + mets.append(met) + table.metrics = mets return table def validate(self) -> None: diff --git a/superset/commands/dataset/importers/v0.py b/superset/commands/dataset/importers/v0.py index acfe4a2c9160e..d6f7380cb5d1d 100644 --- a/superset/commands/dataset/importers/v0.py +++ b/superset/commands/dataset/importers/v0.py @@ -34,6 +34,7 @@ ) from superset.models.core import Database from superset.utils import json +from superset.utils.decorators import transaction from superset.utils.dict_import_export import DATABASES_KEY logger = logging.getLogger(__name__) @@ -211,7 +212,6 @@ def import_from_dict(data: dict[str, Any], sync: Optional[list[str]] = None) -> logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) for database in data.get(DATABASES_KEY, []): Database.import_from_dict(database, sync=sync) - db.session.commit() else: logger.info("Supplied object is not a dictionary.") @@ -240,10 +240,10 @@ def __init__( if kwargs.get("sync_metrics"): self.sync.append("metrics") + @transaction() def run(self) -> None: self.validate() - # TODO (betodealmeida): add rollback in case of error for file_name, config in self._configs.items(): logger.info("Importing dataset from file %s", file_name) if isinstance(config, dict): @@ -260,7 +260,6 @@ def run(self) -> None: ) dataset["database_id"] = database.id SqlaTable.import_from_dict(dataset, sync=self.sync) - db.session.commit() def validate(self) -> None: 
# ensure all files are YAML diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index da39be4721c0c..1c508fe2522e8 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -178,7 +178,7 @@ def import_dataset( if data_uri and (not table_exists or force_data): load_data(data_uri, dataset, dataset.database) - if user := get_user(): + if (user := get_user()) and user not in dataset.owners: dataset.owners.append(user) return dataset diff --git a/superset/commands/dataset/metrics/delete.py b/superset/commands/dataset/metrics/delete.py index b48668852cafd..0a749295dc3d6 100644 --- a/superset/commands/dataset/metrics/delete.py +++ b/superset/commands/dataset/metrics/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Optional from superset import security_manager @@ -26,8 +27,8 @@ ) from superset.connectors.sqla.models import SqlMetric from superset.daos.dataset import DatasetDAO, DatasetMetricDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.exceptions import SupersetSecurityException +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -38,15 +39,11 @@ def __init__(self, dataset_id: int, model_id: int): self._model_id = model_id self._model: Optional[SqlMetric] = None + @transaction(on_error=partial(on_error, reraise=DatasetMetricDeleteFailedError)) def run(self) -> None: self.validate() assert self._model - - try: - DatasetMetricDAO.delete([self._model]) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise DatasetMetricDeleteFailedError() from ex + DatasetMetricDAO.delete([self._model]) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dataset/refresh.py b/superset/commands/dataset/refresh.py 
index 5976956d7cedf..9605ac866a952 100644 --- a/superset/commands/dataset/refresh.py +++ b/superset/commands/dataset/refresh.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Optional from flask_appbuilder.models.sqla import Model @@ -29,6 +30,7 @@ from superset.connectors.sqla.models import SqlaTable from superset.daos.dataset import DatasetDAO from superset.exceptions import SupersetSecurityException +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -38,16 +40,12 @@ def __init__(self, model_id: int): self._model_id = model_id self._model: Optional[SqlaTable] = None + @transaction(on_error=partial(on_error, reraise=DatasetRefreshFailedError)) def run(self) -> Model: self.validate() - if self._model: - try: - self._model.fetch_metadata() - return self._model - except Exception as ex: - logger.exception(ex) - raise DatasetRefreshFailedError() from ex - raise DatasetRefreshFailedError() + assert self._model + self._model.fetch_metadata() + return self._model def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/dataset/update.py b/superset/commands/dataset/update.py index 2b521452436eb..14d1c5ef44707 100644 --- a/superset/commands/dataset/update.py +++ b/superset/commands/dataset/update.py @@ -16,10 +16,12 @@ # under the License. 
import logging from collections import Counter +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model from marshmallow import ValidationError +from sqlalchemy.exc import SQLAlchemyError from superset import security_manager from superset.commands.base import BaseCommand, UpdateMixin @@ -39,9 +41,9 @@ ) from superset.connectors.sqla.models import SqlaTable from superset.daos.dataset import DatasetDAO -from superset.daos.exceptions import DAOUpdateFailedError from superset.exceptions import SupersetSecurityException from superset.sql_parse import Table +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -59,19 +61,20 @@ def __init__( self.override_columns = override_columns self._properties["override_columns"] = override_columns + @transaction( + on_error=partial( + on_error, + catches=( + SQLAlchemyError, + ValueError, + ), + reraise=DatasetUpdateFailedError, + ) + ) def run(self) -> Model: self.validate() - if self._model: - try: - dataset = DatasetDAO.update( - self._model, - attributes=self._properties, - ) - return dataset - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) - raise DatasetUpdateFailedError() from ex - raise DatasetUpdateFailedError() + assert self._model + return DatasetDAO.update(self._model, attributes=self._properties) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/explore/permalink/create.py b/superset/commands/explore/permalink/create.py index 731e0b5ce8a02..2128fa4b8c40e 100644 --- a/superset/commands/explore/permalink/create.py +++ b/superset/commands/explore/permalink/create.py @@ -15,18 +15,22 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from sqlalchemy.exc import SQLAlchemyError -from superset import db from superset.commands.explore.permalink.base import BaseExplorePermalinkCommand from superset.commands.key_value.create import CreateKeyValueCommand from superset.explore.permalink.exceptions import ExplorePermalinkCreateFailedError from superset.explore.utils import check_access as check_chart_access -from superset.key_value.exceptions import KeyValueCodecEncodeException +from superset.key_value.exceptions import ( + KeyValueCodecEncodeException, + KeyValueCreateFailedError, +) from superset.key_value.utils import encode_permalink_key from superset.utils.core import DatasourceType +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -37,35 +41,39 @@ def __init__(self, state: dict[str, Any]): self.datasource: str = state["formData"]["datasource"] self.state = state + @transaction( + on_error=partial( + on_error, + catches=( + KeyValueCodecEncodeException, + KeyValueCreateFailedError, + SQLAlchemyError, + ), + reraise=ExplorePermalinkCreateFailedError, + ), + ) def run(self) -> str: self.validate() - try: - d_id, d_type = self.datasource.split("__") - datasource_id = int(d_id) - datasource_type = DatasourceType(d_type) - check_chart_access(datasource_id, self.chart_id, datasource_type) - value = { - "chartId": self.chart_id, - "datasourceId": datasource_id, - "datasourceType": datasource_type.value, - "datasource": self.datasource, - "state": self.state, - } - command = CreateKeyValueCommand( - resource=self.resource, - value=value, - codec=self.codec, - ) - key = command.run() - if key.id is None: - raise ExplorePermalinkCreateFailedError("Unexpected missing key id") - db.session.commit() - return encode_permalink_key(key=key.id, salt=self.salt) - except KeyValueCodecEncodeException as ex: - raise ExplorePermalinkCreateFailedError(str(ex)) from ex - except SQLAlchemyError 
as ex: - logger.exception("Error running create command") - raise ExplorePermalinkCreateFailedError() from ex + d_id, d_type = self.datasource.split("__") + datasource_id = int(d_id) + datasource_type = DatasourceType(d_type) + check_chart_access(datasource_id, self.chart_id, datasource_type) + value = { + "chartId": self.chart_id, + "datasourceId": datasource_id, + "datasourceType": datasource_type.value, + "datasource": self.datasource, + "state": self.state, + } + command = CreateKeyValueCommand( + resource=self.resource, + value=value, + codec=self.codec, + ) + key = command.run() + if key.id is None: + raise ExplorePermalinkCreateFailedError("Unexpected missing key id") + return encode_permalink_key(key=key.id, salt=self.salt) def validate(self) -> None: pass diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index 25b8b8790f046..f90708acf51f1 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -32,6 +32,7 @@ ) from superset.daos.base import BaseDAO from superset.models.core import Database # noqa: F401 +from superset.utils.decorators import transaction class ImportModelsCommand(BaseCommand): @@ -67,18 +68,15 @@ def _import(configs: dict[str, Any], overwrite: bool = False) -> None: def _get_uuids(cls) -> set[str]: return {str(model.uuid) for model in db.session.query(cls.dao.model_cls).all()} + @transaction() def run(self) -> None: self.validate() - # rollback to prevent partial imports try: self._import(self._configs, self.overwrite) - db.session.commit() except CommandException: - db.session.rollback() raise except Exception as ex: - db.session.rollback() raise self.import_error() from ex def validate(self) -> None: # noqa: F811 diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index 29a2dec179084..78a2251a293af 100644 --- a/superset/commands/importers/v1/assets.py +++ 
b/superset/commands/importers/v1/assets.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from functools import partial from typing import Any, Optional from marshmallow import Schema @@ -44,6 +45,7 @@ from superset.migrations.shared.native_filters import migrate_dashboard from superset.models.dashboard import dashboard_slices from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema +from superset.utils.decorators import on_error, transaction class ImportAssetsCommand(BaseCommand): @@ -153,16 +155,16 @@ def _import(configs: dict[str, Any]) -> None: if chart.viz_type == "filter_box": db.session.delete(chart) + @transaction( + on_error=partial( + on_error, + catches=(Exception,), + reraise=ImportFailedError, + ) + ) def run(self) -> None: self.validate() - - # rollback to prevent partial imports - try: - self._import(self._configs) - db.session.commit() - except Exception as ex: - db.session.rollback() - raise ImportFailedError() from ex + self._import(self._configs) def validate(self) -> None: exceptions: list[ValidationError] = [] diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index 6525031ce4f36..bcf6b5062fb9b 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -43,6 +43,7 @@ from superset.models.dashboard import dashboard_slices from superset.utils.core import get_example_default_schema from superset.utils.database import get_example_database +from superset.utils.decorators import transaction class ImportExamplesCommand(ImportModelsCommand): @@ -62,19 +63,17 @@ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): super().__init__(contents, *args, **kwargs) self.force_data = kwargs.get("force_data", False) + @transaction() def run(self) -> None: self.validate() - # rollback to prevent partial imports try: 
self._import( self._configs, self.overwrite, self.force_data, ) - db.session.commit() except Exception as ex: - db.session.rollback() raise self.import_error() from ex @classmethod diff --git a/superset/commands/key_value/create.py b/superset/commands/key_value/create.py index 7308321e44a73..81b7c4c3d4a93 100644 --- a/superset/commands/key_value/create.py +++ b/superset/commands/key_value/create.py @@ -16,17 +16,17 @@ # under the License. import logging from datetime import datetime +from functools import partial from typing import Any, Optional, Union from uuid import UUID -from sqlalchemy.exc import SQLAlchemyError - from superset import db from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.models import KeyValueEntry from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.utils.core import get_user_id +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -62,6 +62,7 @@ def __init__( # pylint: disable=too-many-arguments self.key = key self.expires_on = expires_on + @transaction(on_error=partial(on_error, reraise=KeyValueCreateFailedError)) def run(self) -> Key: """ Persist the value @@ -69,11 +70,8 @@ def run(self) -> Key: :return: the key associated with the persisted value """ - try: - return self.create() - except SQLAlchemyError as ex: - db.session.rollback() - raise KeyValueCreateFailedError() from ex + + return self.create() def validate(self) -> None: pass @@ -98,6 +96,7 @@ def create(self) -> Key: entry.id = self.key except ValueError as ex: raise KeyValueCreateFailedError() from ex + db.session.add(entry) db.session.flush() return Key(id=entry.id, uuid=entry.uuid) diff --git a/superset/commands/key_value/delete.py b/superset/commands/key_value/delete.py index 37eb7087e6a20..a3fdf079c73c2 100644 --- a/superset/commands/key_value/delete.py +++ b/superset/commands/key_value/delete.py @@ -15,17 
+15,17 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Union from uuid import UUID -from sqlalchemy.exc import SQLAlchemyError - from superset import db from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.key_value.models import KeyValueEntry from superset.key_value.types import KeyValueResource from superset.key_value.utils import get_filter +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -45,20 +45,19 @@ def __init__(self, resource: KeyValueResource, key: Union[int, UUID]): self.resource = resource self.key = key + @transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError)) def run(self) -> bool: - try: - return self.delete() - except SQLAlchemyError as ex: - db.session.rollback() - raise KeyValueDeleteFailedError() from ex + return self.delete() def validate(self) -> None: pass def delete(self) -> bool: - filter_ = get_filter(self.resource, self.key) - if entry := db.session.query(KeyValueEntry).filter_by(**filter_).first(): + if ( + entry := db.session.query(KeyValueEntry) + .filter_by(**get_filter(self.resource, self.key)) + .first() + ): db.session.delete(entry) - db.session.flush() return True return False diff --git a/superset/commands/key_value/delete_expired.py b/superset/commands/key_value/delete_expired.py index 92d45683f222e..54991c7531d27 100644 --- a/superset/commands/key_value/delete_expired.py +++ b/superset/commands/key_value/delete_expired.py @@ -16,15 +16,16 @@ # under the License. 
import logging from datetime import datetime +from functools import partial from sqlalchemy import and_ -from sqlalchemy.exc import SQLAlchemyError from superset import db from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueDeleteFailedError from superset.key_value.models import KeyValueEntry from superset.key_value.types import KeyValueResource +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -41,12 +42,9 @@ def __init__(self, resource: KeyValueResource): """ self.resource = resource + @transaction(on_error=partial(on_error, reraise=KeyValueDeleteFailedError)) def run(self) -> None: - try: - self.delete_expired() - except SQLAlchemyError as ex: - db.session.rollback() - raise KeyValueDeleteFailedError() from ex + self.delete_expired() def validate(self) -> None: pass @@ -62,4 +60,3 @@ def delete_expired(self) -> None: ) .delete() ) - db.session.flush() diff --git a/superset/commands/key_value/update.py b/superset/commands/key_value/update.py index 098c9f860d1b6..b6ffc22174f60 100644 --- a/superset/commands/key_value/update.py +++ b/superset/commands/key_value/update.py @@ -17,11 +17,10 @@ import logging from datetime import datetime +from functools import partial from typing import Any, Optional, Union from uuid import UUID -from sqlalchemy.exc import SQLAlchemyError - from superset import db from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueUpdateFailedError @@ -29,6 +28,7 @@ from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.utils import get_filter from superset.utils.core import get_user_id +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -64,12 +64,9 @@ def __init__( # pylint: disable=too-many-arguments self.codec = codec self.expires_on = expires_on + @transaction(on_error=partial(on_error, 
reraise=KeyValueUpdateFailedError)) def run(self) -> Optional[Key]: - try: - return self.update() - except SQLAlchemyError as ex: - db.session.rollback() - raise KeyValueUpdateFailedError() from ex + return self.update() def validate(self) -> None: pass diff --git a/superset/commands/key_value/upsert.py b/superset/commands/key_value/upsert.py index 2c985530bf20a..32918d9b14396 100644 --- a/superset/commands/key_value/upsert.py +++ b/superset/commands/key_value/upsert.py @@ -17,6 +17,7 @@ import logging from datetime import datetime +from functools import partial from typing import Any, Optional, Union from uuid import UUID @@ -33,6 +34,7 @@ from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.utils import get_filter from superset.utils.core import get_user_id +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -68,27 +70,29 @@ def __init__( # pylint: disable=too-many-arguments self.codec = codec self.expires_on = expires_on + @transaction( + on_error=partial( + on_error, + catches=(KeyValueCreateFailedError, SQLAlchemyError), + reraise=KeyValueUpsertFailedError, + ), + ) def run(self) -> Key: - try: - return self.upsert() - except (KeyValueCreateFailedError, SQLAlchemyError) as ex: - db.session.rollback() - raise KeyValueUpsertFailedError() from ex + return self.upsert() def validate(self) -> None: pass def upsert(self) -> Key: - filter_ = get_filter(self.resource, self.key) - entry: KeyValueEntry = ( - db.session.query(KeyValueEntry).filter_by(**filter_).first() - ) - if entry: + if ( + entry := db.session.query(KeyValueEntry) + .filter_by(**get_filter(self.resource, self.key)) + .first() + ): entry.value = self.codec.encode(self.value) entry.expires_on = self.expires_on entry.changed_on = datetime.now() entry.changed_by_fk = get_user_id() - db.session.flush() return Key(entry.id, entry.uuid) return CreateKeyValueCommand( diff --git a/superset/commands/query/delete.py 
b/superset/commands/query/delete.py index 978f30c5c4a87..a93c4038abf46 100644 --- a/superset/commands/query/delete.py +++ b/superset/commands/query/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Optional from superset.commands.base import BaseCommand @@ -22,9 +23,9 @@ SavedQueryDeleteFailedError, SavedQueryNotFoundError, ) -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.query import SavedQueryDAO from superset.models.dashboard import Dashboard +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -34,15 +35,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[Dashboard]] = None + @transaction(on_error=partial(on_error, reraise=SavedQueryDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - SavedQueryDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise SavedQueryDeleteFailedError() from ex + SavedQueryDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/report/create.py b/superset/commands/report/create.py index ed1889e8b3321..2a67f640022d2 100644 --- a/superset/commands/report/create.py +++ b/superset/commands/report/create.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Any, Optional from flask_babel import gettext as _ @@ -31,7 +32,6 @@ ReportScheduleNameUniquenessValidationError, ) from superset.daos.database import DatabaseDAO -from superset.daos.exceptions import DAOCreateFailedError from superset.daos.report import ReportScheduleDAO from superset.reports.models import ( ReportCreationMethod, @@ -40,6 +40,7 @@ ) from superset.reports.types import ReportScheduleExtra from superset.utils import json +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -48,13 +49,10 @@ class CreateReportScheduleCommand(CreateMixin, BaseReportScheduleCommand): def __init__(self, data: dict[str, Any]): self._properties = data.copy() + @transaction(on_error=partial(on_error, reraise=ReportScheduleCreateFailedError)) def run(self) -> ReportSchedule: self.validate() - try: - return ReportScheduleDAO.create(attributes=self._properties) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise ReportScheduleCreateFailedError() from ex + return ReportScheduleDAO.create(attributes=self._properties) def validate(self) -> None: """ diff --git a/superset/commands/report/delete.py b/superset/commands/report/delete.py index 87ea4b99dd017..36e6711105c82 100644 --- a/superset/commands/report/delete.py +++ b/superset/commands/report/delete.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. 
import logging +from functools import partial from typing import Optional from superset import security_manager @@ -24,10 +25,10 @@ ReportScheduleForbiddenError, ReportScheduleNotFoundError, ) -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.report import ReportScheduleDAO from superset.exceptions import SupersetSecurityException from superset.reports.models import ReportSchedule +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -37,15 +38,11 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: Optional[list[ReportSchedule]] = None + @transaction(on_error=partial(on_error, reraise=ReportScheduleDeleteFailedError)) def run(self) -> None: self.validate() assert self._models - - try: - ReportScheduleDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise ReportScheduleDeleteFailedError() from ex + ReportScheduleDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/report/execute.py b/superset/commands/report/execute.py index 637898a7a0a50..c57828eac497b 100644 --- a/superset/commands/report/execute.py +++ b/superset/commands/report/execute.py @@ -69,7 +69,7 @@ from superset.utils import json from superset.utils.core import HeaderDataType, override_user from superset.utils.csv import get_chart_csv_data, get_chart_dataframe -from superset.utils.decorators import logs_context +from superset.utils.decorators import logs_context, transaction from superset.utils.pdf import build_pdf_from_screenshots from superset.utils.screenshots import ChartScreenshot, DashboardScreenshot from superset.utils.urls import get_url_path @@ -120,7 +120,6 @@ def update_report_schedule(self, state: ReportState) -> None: self._report_schedule.last_state = state self._report_schedule.last_eval_dttm = datetime.utcnow() - db.session.commit() def create_log(self, error_message: 
Optional[str] = None) -> None: """ @@ -138,7 +137,7 @@ def create_log(self, error_message: Optional[str] = None) -> None: uuid=self._execution_id, ) db.session.add(log) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction def _get_url( self, @@ -690,6 +689,7 @@ def __init__( self._report_schedule = report_schedule self._scheduled_dttm = scheduled_dttm + @transaction() def run(self) -> None: for state_cls in self.states_cls: if (self._report_schedule.last_state is None and state_cls.initial) or ( @@ -718,6 +718,7 @@ def __init__(self, task_id: str, model_id: int, scheduled_dttm: datetime): self._scheduled_dttm = scheduled_dttm self._execution_id = UUID(task_id) + @transaction() def run(self) -> None: try: self.validate() diff --git a/superset/commands/report/log_prune.py b/superset/commands/report/log_prune.py index f14f7856a1e15..a780bf51e0333 100644 --- a/superset/commands/report/log_prune.py +++ b/superset/commands/report/log_prune.py @@ -17,12 +17,14 @@ import logging from datetime import datetime, timedelta +from sqlalchemy.exc import SQLAlchemyError + from superset import db from superset.commands.base import BaseCommand from superset.commands.report.exceptions import ReportSchedulePruneLogError -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.report import ReportScheduleDAO from superset.reports.models import ReportSchedule +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -32,9 +34,7 @@ class AsyncPruneReportScheduleLogCommand(BaseCommand): Prunes logs from all report schedules """ - def __init__(self, worker_context: bool = True): - self._worker_context = worker_context - + @transaction() def run(self) -> None: self.validate() prune_errors = [] @@ -46,15 +46,15 @@ def run(self) -> None: ) try: row_count = ReportScheduleDAO.bulk_delete_logs( - report_schedule, from_date, commit=False + report_schedule, + from_date, ) - db.session.commit() logger.info( 
"Deleted %s logs for report schedule id: %s", str(row_count), str(report_schedule.id), ) - except DAODeleteFailedError as ex: + except SQLAlchemyError as ex: prune_errors.append(str(ex)) if prune_errors: raise ReportSchedulePruneLogError(";".join(prune_errors)) diff --git a/superset/commands/report/update.py b/superset/commands/report/update.py index ad54f44f0618d..2aab3bd8c4520 100644 --- a/superset/commands/report/update.py +++ b/superset/commands/report/update.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Any, Optional from flask_appbuilder.models.sqla import Model @@ -32,11 +33,11 @@ ReportScheduleUpdateFailedError, ) from superset.daos.database import DatabaseDAO -from superset.daos.exceptions import DAOUpdateFailedError from superset.daos.report import ReportScheduleDAO from superset.exceptions import SupersetSecurityException from superset.reports.models import ReportSchedule, ReportScheduleType, ReportState from superset.utils import json +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -47,16 +48,10 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[ReportSchedule] = None + @transaction(on_error=partial(on_error, reraise=ReportScheduleUpdateFailedError)) def run(self) -> Model: self.validate() - assert self._model - - try: - report_schedule = ReportScheduleDAO.update(self._model, self._properties) - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) - raise ReportScheduleUpdateFailedError() from ex - return report_schedule + return ReportScheduleDAO.update(self._model, self._properties) def validate(self) -> None: """ diff --git a/superset/commands/security/create.py b/superset/commands/security/create.py index d70bbb7111a87..0288cf4d0b90e 100644 --- a/superset/commands/security/create.py +++ 
b/superset/commands/security/create.py @@ -23,9 +23,9 @@ from superset.commands.exceptions import DatasourceNotFoundValidationError from superset.commands.utils import populate_roles from superset.connectors.sqla.models import SqlaTable -from superset.daos.exceptions import DAOCreateFailedError from superset.daos.security import RLSDAO from superset.extensions import db +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -36,13 +36,10 @@ def __init__(self, data: dict[str, Any]): self._tables = self._properties.get("tables", []) self._roles = self._properties.get("roles", []) + @transaction() def run(self) -> Any: self.validate() - try: - return RLSDAO.create(attributes=self._properties) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise + return RLSDAO.create(attributes=self._properties) def validate(self) -> None: roles = populate_roles(self._roles) diff --git a/superset/commands/security/delete.py b/superset/commands/security/delete.py index 2c19c5f89b78c..662474c27edc1 100644 --- a/superset/commands/security/delete.py +++ b/superset/commands/security/delete.py @@ -16,15 +16,16 @@ # under the License. 
import logging +from functools import partial from superset.commands.base import BaseCommand from superset.commands.security.exceptions import ( RLSRuleNotFoundError, RuleDeleteFailedError, ) -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.security import RLSDAO from superset.reports.models import ReportSchedule +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -34,13 +35,10 @@ def __init__(self, model_ids: list[int]): self._model_ids = model_ids self._models: list[ReportSchedule] = [] + @transaction(on_error=partial(on_error, reraise=RuleDeleteFailedError)) def run(self) -> None: self.validate() - try: - RLSDAO.delete(self._models) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise RuleDeleteFailedError() from ex + RLSDAO.delete(self._models) def validate(self) -> None: # Validate/populate model exists diff --git a/superset/commands/security/update.py b/superset/commands/security/update.py index 54d7a66a2a238..fa17b249b47b8 100644 --- a/superset/commands/security/update.py +++ b/superset/commands/security/update.py @@ -24,9 +24,9 @@ from superset.commands.security.exceptions import RLSRuleNotFoundError from superset.commands.utils import populate_roles from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable -from superset.daos.exceptions import DAOUpdateFailedError from superset.daos.security import RLSDAO from superset.extensions import db +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -39,17 +39,11 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._roles = self._properties.get("roles", []) self._model: Optional[RowLevelSecurityFilter] = None + @transaction() def run(self) -> Any: self.validate() assert self._model - - try: - rule = RLSDAO.update(self._model, self._properties) - except DAOUpdateFailedError as ex: - logger.exception(ex.exception) - raise - - return rule + return 
RLSDAO.update(self._model, self._properties) def validate(self) -> None: self._model = RLSDAO.find_by_id(int(self._model_id)) diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py index 911424af51ce4..0c3e33b529169 100644 --- a/superset/commands/sql_lab/execute.py +++ b/superset/commands/sql_lab/execute.py @@ -22,10 +22,11 @@ from typing import Any, TYPE_CHECKING from flask_babel import gettext as __ +from sqlalchemy.exc import SQLAlchemyError +from superset import db from superset.commands.base import BaseCommand from superset.common.db_query_status import QueryStatus -from superset.daos.exceptions import DAOCreateFailedError from superset.errors import SupersetErrorType from superset.exceptions import ( SupersetErrorException, @@ -41,6 +42,7 @@ ) from superset.sqllab.execution_context_convertor import ExecutionContextConvertor from superset.sqllab.limiting_factor import LimitingFactor +from superset.utils.decorators import transaction if TYPE_CHECKING: from superset.daos.database import DatabaseDAO @@ -90,6 +92,7 @@ def __init__( def validate(self) -> None: pass + @transaction() def run( # pylint: disable=too-many-statements,useless-suppression self, ) -> CommandResult: @@ -178,9 +181,22 @@ def _validate_query_db(cls, database: Database | None) -> None: ) def _save_new_query(self, query: Query) -> None: + """ + Saves the new SQL Lab query. + + Committing within a transaction violates the "unit of work" construct, but is + necessary for async querying. The Celery task is defined within the confines + of another command and needs to read a previously committed state given the + `READ COMMITTED` isolation level. + + To mitigate said issue, ideally there would be a command to prepare said query + and another to execute it, either in a sync or async manner. 
+ + :param query: The SQL Lab query + """ try: self._query_dao.create(query) - except DAOCreateFailedError as ex: + except SQLAlchemyError as ex: raise SqlLabException( self._execution_context, SupersetErrorType.GENERIC_DB_ENGINE_ERROR, @@ -189,6 +205,8 @@ def _save_new_query(self, query: Query) -> None: "Please contact an administrator for further assistance or try again.", ) from ex + db.session.commit() # pylint: disable=consider-using-transaction + def _validate_access(self, query: Query) -> None: try: self._access_validator.validate(query) diff --git a/superset/commands/tag/create.py b/superset/commands/tag/create.py index ea23b8d59da10..775250dc8172d 100644 --- a/superset/commands/tag/create.py +++ b/superset/commands/tag/create.py @@ -15,16 +15,17 @@ # specific language governing permissions and limitations # under the License. import logging +from functools import partial from typing import Any -from superset import db, security_manager +from superset import security_manager from superset.commands.base import BaseCommand, CreateMixin from superset.commands.tag.exceptions import TagCreateFailedError, TagInvalidError from superset.commands.tag.utils import to_object_model, to_object_type -from superset.daos.exceptions import DAOCreateFailedError from superset.daos.tag import TagDAO from superset.exceptions import SupersetSecurityException from superset.tags.models import ObjectType, TagType +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -35,20 +36,18 @@ def __init__(self, object_type: ObjectType, object_id: int, tags: list[str]): self._object_id = object_id self._tags = tags + @transaction(on_error=partial(on_error, reraise=TagCreateFailedError)) def run(self) -> None: self.validate() - try: - object_type = to_object_type(self._object_type) - if object_type is None: - raise TagCreateFailedError(f"invalid object type {self._object_type}") - TagDAO.create_custom_tagged_objects( - object_type=object_type, - 
object_id=self._object_id, - tag_names=self._tags, - ) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise TagCreateFailedError() from ex + object_type = to_object_type(self._object_type) + if object_type is None: + raise TagCreateFailedError(f"invalid object type {self._object_type}") + + TagDAO.create_custom_tagged_objects( + object_type=object_type, + object_id=self._object_id, + tag_names=self._tags, + ) def validate(self) -> None: exceptions = [] @@ -71,27 +70,20 @@ def __init__(self, data: dict[str, Any], bulk_create: bool = False): self._bulk_create = bulk_create self._skipped_tagged_objects: set[tuple[str, int]] = set() + @transaction(on_error=partial(on_error, reraise=TagCreateFailedError)) def run(self) -> tuple[set[tuple[str, int]], set[tuple[str, int]]]: self.validate() - try: - tag_name = self._properties["name"] - tag = TagDAO.get_by_name(tag_name.strip(), TagType.custom) - TagDAO.create_tag_relationship( - objects_to_tag=self._properties.get("objects_to_tag", []), - tag=tag, - bulk_create=self._bulk_create, - ) - - tag.description = self._properties.get("description", "") - - db.session.commit() - - return set(self._properties["objects_to_tag"]), self._skipped_tagged_objects + tag_name = self._properties["name"] + tag = TagDAO.get_by_name(tag_name.strip(), TagType.custom) + TagDAO.create_tag_relationship( + objects_to_tag=self._properties.get("objects_to_tag", []), + tag=tag, + bulk_create=self._bulk_create, + ) - except DAOCreateFailedError as ex: - logger.exception(ex.exception) - raise TagCreateFailedError() from ex + tag.description = self._properties.get("description", "") + return set(self._properties["objects_to_tag"]), self._skipped_tagged_objects def validate(self) -> None: exceptions = [] diff --git a/superset/commands/tag/delete.py b/superset/commands/tag/delete.py index c4f22390095dc..89a2a5a5568d8 100644 --- a/superset/commands/tag/delete.py +++ b/superset/commands/tag/delete.py @@ -15,6 +15,7 @@ # specific 
language governing permissions and limitations # under the License. import logging +from functools import partial from superset.commands.base import BaseCommand from superset.commands.tag.exceptions import ( @@ -25,9 +26,9 @@ TagNotFoundError, ) from superset.commands.tag.utils import to_object_type -from superset.daos.exceptions import DAODeleteFailedError from superset.daos.tag import TagDAO from superset.tags.models import ObjectType +from superset.utils.decorators import on_error, transaction from superset.views.base import DeleteMixin logger = logging.getLogger(__name__) @@ -39,18 +40,15 @@ def __init__(self, object_type: ObjectType, object_id: int, tag: str): self._object_id = object_id self._tag = tag + @transaction(on_error=partial(on_error, reraise=TaggedObjectDeleteFailedError)) def run(self) -> None: self.validate() - try: - object_type = to_object_type(self._object_type) - if object_type is None: - raise TaggedObjectDeleteFailedError( - f"invalid object type {self._object_type}" - ) - TagDAO.delete_tagged_object(object_type, self._object_id, self._tag) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise TaggedObjectDeleteFailedError() from ex + object_type = to_object_type(self._object_type) + if object_type is None: + raise TaggedObjectDeleteFailedError( + f"invalid object type {self._object_type}" + ) + TagDAO.delete_tagged_object(object_type, self._object_id, self._tag) def validate(self) -> None: exceptions = [] @@ -92,13 +90,10 @@ class DeleteTagsCommand(DeleteMixin, BaseCommand): def __init__(self, tags: list[str]): self._tags = tags + @transaction(on_error=partial(on_error, reraise=TagDeleteFailedError)) def run(self) -> None: self.validate() - try: - TagDAO.delete_tags(self._tags) - except DAODeleteFailedError as ex: - logger.exception(ex.exception) - raise TagDeleteFailedError() from ex + TagDAO.delete_tags(self._tags) def validate(self) -> None: exceptions = [] diff --git a/superset/commands/tag/update.py 
b/superset/commands/tag/update.py index 431bf93c4de8c..fa5e125414cfb 100644 --- a/superset/commands/tag/update.py +++ b/superset/commands/tag/update.py @@ -25,6 +25,7 @@ from superset.commands.tag.utils import to_object_type from superset.daos.tag import TagDAO from superset.tags.models import Tag +from superset.utils.decorators import transaction logger = logging.getLogger(__name__) @@ -35,18 +36,17 @@ def __init__(self, model_id: int, data: dict[str, Any]): self._properties = data.copy() self._model: Optional[Tag] = None + @transaction() def run(self) -> Model: self.validate() - if self._model: - self._model.name = self._properties["name"] - TagDAO.create_tag_relationship( - objects_to_tag=self._properties.get("objects_to_tag", []), - tag=self._model, - ) - self._model.description = self._properties.get("description") - - db.session.add(self._model) - db.session.commit() + assert self._model + self._model.name = self._properties["name"] + TagDAO.create_tag_relationship( + objects_to_tag=self._properties.get("objects_to_tag", []), + tag=self._model, + ) + self._model.description = self._properties.get("description") + db.session.add(self._model) return self._model diff --git a/superset/commands/temporary_cache/create.py b/superset/commands/temporary_cache/create.py index 7d61a78074625..642e812d02065 100644 --- a/superset/commands/temporary_cache/create.py +++ b/superset/commands/temporary_cache/create.py @@ -16,12 +16,12 @@ # under the License. 
import logging from abc import ABC, abstractmethod - -from sqlalchemy.exc import SQLAlchemyError +from functools import partial from superset.commands.base import BaseCommand from superset.commands.temporary_cache.exceptions import TemporaryCacheCreateFailedError from superset.commands.temporary_cache.parameters import CommandParameters +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -30,12 +30,9 @@ class CreateTemporaryCacheCommand(BaseCommand, ABC): def __init__(self, cmd_params: CommandParameters): self._cmd_params = cmd_params + @transaction(on_error=partial(on_error, reraise=TemporaryCacheCreateFailedError)) def run(self) -> str: - try: - return self.create(self._cmd_params) - except SQLAlchemyError as ex: - logger.exception("Error running create command") - raise TemporaryCacheCreateFailedError() from ex + return self.create(self._cmd_params) def validate(self) -> None: pass diff --git a/superset/commands/temporary_cache/delete.py b/superset/commands/temporary_cache/delete.py index 1cc291dbf6e2e..25cc25ec7ad4f 100644 --- a/superset/commands/temporary_cache/delete.py +++ b/superset/commands/temporary_cache/delete.py @@ -16,12 +16,12 @@ # under the License. 
import logging from abc import ABC, abstractmethod - -from sqlalchemy.exc import SQLAlchemyError +from functools import partial from superset.commands.base import BaseCommand from superset.commands.temporary_cache.exceptions import TemporaryCacheDeleteFailedError from superset.commands.temporary_cache.parameters import CommandParameters +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -30,12 +30,9 @@ class DeleteTemporaryCacheCommand(BaseCommand, ABC): def __init__(self, cmd_params: CommandParameters): self._cmd_params = cmd_params + @transaction(on_error=partial(on_error, reraise=TemporaryCacheDeleteFailedError)) def run(self) -> bool: - try: - return self.delete(self._cmd_params) - except SQLAlchemyError as ex: - logger.exception("Error running delete command") - raise TemporaryCacheDeleteFailedError() from ex + return self.delete(self._cmd_params) def validate(self) -> None: pass diff --git a/superset/commands/temporary_cache/update.py b/superset/commands/temporary_cache/update.py index 8daaae8618bab..88bbe18b852c1 100644 --- a/superset/commands/temporary_cache/update.py +++ b/superset/commands/temporary_cache/update.py @@ -16,13 +16,13 @@ # under the License. 
import logging from abc import ABC, abstractmethod +from functools import partial from typing import Optional -from sqlalchemy.exc import SQLAlchemyError - from superset.commands.base import BaseCommand from superset.commands.temporary_cache.exceptions import TemporaryCacheUpdateFailedError from superset.commands.temporary_cache.parameters import CommandParameters +from superset.utils.decorators import on_error, transaction logger = logging.getLogger(__name__) @@ -34,12 +34,9 @@ def __init__( ): self._parameters = cmd_params + @transaction(on_error=partial(on_error, reraise=TemporaryCacheUpdateFailedError)) def run(self) -> Optional[str]: - try: - return self.update(self._parameters) - except SQLAlchemyError as ex: - logger.exception("Error running update command") - raise TemporaryCacheUpdateFailedError() from ex + return self.update(self._parameters) def validate(self) -> None: pass diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 6d8d87a506c00..c38a0085a534b 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1768,11 +1768,10 @@ def get_sqla_table_object(self) -> Table: ) ) - def fetch_metadata(self, commit: bool = True) -> MetadataResult: + def fetch_metadata(self) -> MetadataResult: """ Fetches the metadata for the table and merges it in - :param commit: should the changes be committed or not. :return: Tuple with lists of added, removed and modified column names. 
""" new_columns = self.external_metadata() @@ -1850,8 +1849,6 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: config["SQLA_TABLE_MUTATOR"](self) db.session.merge(self) - if commit: - db.session.commit() return results @classmethod diff --git a/superset/daos/base.py b/superset/daos/base.py index 889a0780f642f..e393034062b87 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -21,13 +21,8 @@ from flask_appbuilder.models.filters import BaseFilter from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla.interface import SQLAInterface -from sqlalchemy.exc import SQLAlchemyError, StatementError +from sqlalchemy.exc import StatementError -from superset.daos.exceptions import ( - DAOCreateFailedError, - DAODeleteFailedError, - DAOUpdateFailedError, -) from superset.extensions import db T = TypeVar("T", bound=Model) @@ -127,15 +122,12 @@ def create( cls, item: T | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> T: """ Create an object from the specified item and/or attributes. :param item: The object to create :param attributes: The attributes associated with the object to create - :param commit: Whether to commit the transaction - :raises DAOCreateFailedError: If the creation failed """ if not item: @@ -145,15 +137,7 @@ def create( for key, value in attributes.items(): setattr(item, key, value) - try: - db.session.add(item) - - if commit: - db.session.commit() - except SQLAlchemyError as ex: # pragma: no cover - db.session.rollback() - raise DAOCreateFailedError(exception=ex) from ex - + db.session.add(item) return item # type: ignore @classmethod @@ -161,15 +145,12 @@ def update( cls, item: T | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> T: """ Update an object from the specified item and/or attributes. 
:param item: The object to update :param attributes: The attributes associated with the object to update - :param commit: Whether to commit the transaction - :raises DAOUpdateFailedError: If the updating failed """ if not item: @@ -179,19 +160,13 @@ def update( for key, value in attributes.items(): setattr(item, key, value) - try: - db.session.merge(item) - - if commit: - db.session.commit() - except SQLAlchemyError as ex: # pragma: no cover - db.session.rollback() - raise DAOUpdateFailedError(exception=ex) from ex + if item not in db.session: + return db.session.merge(item) return item # type: ignore @classmethod - def delete(cls, items: list[T], commit: bool = True) -> None: + def delete(cls, items: list[T]) -> None: """ Delete the specified items including their associated relationships. @@ -204,17 +179,8 @@ def delete(cls, items: list[T], commit: bool = True) -> None: post-deletion logic. :param items: The items to delete - :param commit: Whether to commit the transaction - :raises DAODeleteFailedError: If the deletion failed :see: https://docs.sqlalchemy.org/en/latest/orm/queryguide/dml.html """ - try: - for item in items: - db.session.delete(item) - - if commit: - db.session.commit() - except SQLAlchemyError as ex: - db.session.rollback() - raise DAODeleteFailedError(exception=ex) from ex + for item in items: + db.session.delete(item) diff --git a/superset/daos/chart.py b/superset/daos/chart.py index 844b36b6b3d4c..35afb7f7a91be 100644 --- a/superset/daos/chart.py +++ b/superset/daos/chart.py @@ -62,7 +62,6 @@ def add_favorite(chart: Slice) -> None: dttm=datetime.now(), ) ) - db.session.commit() @staticmethod def remove_favorite(chart: Slice) -> None: @@ -77,4 +76,3 @@ def remove_favorite(chart: Slice) -> None: ) if fav: db.session.delete(fav) - db.session.commit() diff --git a/superset/daos/dashboard.py b/superset/daos/dashboard.py index 6c973639b73ef..8196c197b2487 100644 --- a/superset/daos/dashboard.py +++ b/superset/daos/dashboard.py @@ -179,8 +179,7 @@ 
def set_dash_metadata( dashboard: Dashboard, data: dict[Any, Any], old_to_new_slice_ids: dict[int, int] | None = None, - commit: bool = False, - ) -> Dashboard: + ) -> None: new_filter_scopes = {} md = dashboard.params_dict @@ -265,10 +264,6 @@ def set_dash_metadata( md["cross_filters_enabled"] = data.get("cross_filters_enabled", True) dashboard.json_metadata = json.dumps(md) - if commit: - db.session.commit() - return dashboard - @staticmethod def favorited_ids(dashboards: list[Dashboard]) -> list[FavStar]: ids = [dash.id for dash in dashboards] @@ -321,7 +316,6 @@ def copy_dashboard( dash.params = original_dash.params cls.set_dash_metadata(dash, metadata, old_to_new_slice_ids) db.session.add(dash) - db.session.commit() return dash @staticmethod @@ -336,7 +330,6 @@ def add_favorite(dashboard: Dashboard) -> None: dttm=datetime.now(), ) ) - db.session.commit() @staticmethod def remove_favorite(dashboard: Dashboard) -> None: @@ -351,7 +344,6 @@ def remove_favorite(dashboard: Dashboard) -> None: ) if fav: db.session.delete(fav) - db.session.commit() class EmbeddedDashboardDAO(BaseDAO[EmbeddedDashboard]): @@ -369,7 +361,6 @@ def upsert(dashboard: Dashboard, allowed_domains: list[str]) -> EmbeddedDashboar ) embedded.allow_domain_list = ",".join(allowed_domains) dashboard.embedded = [embedded] - db.session.commit() return embedded @classmethod @@ -377,7 +368,6 @@ def create( cls, item: EmbeddedDashboardDAO | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> Any: """ Use EmbeddedDashboardDAO.upsert() instead. diff --git a/superset/daos/database.py b/superset/daos/database.py index 15fc03710aa7c..06b429bb6bf1d 100644 --- a/superset/daos/database.py +++ b/superset/daos/database.py @@ -42,7 +42,6 @@ def update( cls, item: Database | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> Database: """ Unmask ``encrypted_extra`` before updating. 
@@ -60,7 +59,7 @@ def update( attributes["encrypted_extra"], ) - return super().update(item, attributes, commit) + return super().update(item, attributes) @staticmethod def validate_uniqueness(database_name: str) -> bool: @@ -174,7 +173,6 @@ def update( cls, item: SSHTunnel | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> SSHTunnel: """ Unmask ``password``, ``private_key`` and ``private_key_password`` before updating. @@ -190,7 +188,7 @@ def update( attributes.pop("id", None) attributes = unmask_password_info(attributes, item) - return super().update(item, attributes, commit) + return super().update(item, attributes) class DatabaseUserOAuth2TokensDAO(BaseDAO[DatabaseUserOAuth2Tokens]): diff --git a/superset/daos/dataset.py b/superset/daos/dataset.py index 21c5ae1d0faf0..af1b705d66109 100644 --- a/superset/daos/dataset.py +++ b/superset/daos/dataset.py @@ -25,7 +25,6 @@ from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.daos.base import BaseDAO -from superset.daos.exceptions import DAOUpdateFailedError from superset.extensions import db from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -171,7 +170,6 @@ def update( cls, item: SqlaTable | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> SqlaTable: """ Updates a Dataset model on the metadata DB @@ -182,21 +180,19 @@ def update( cls.update_columns( item, attributes.pop("columns"), - commit=commit, override_columns=bool(attributes.get("override_columns")), ) if "metrics" in attributes: - cls.update_metrics(item, attributes.pop("metrics"), commit=commit) + cls.update_metrics(item, attributes.pop("metrics")) - return super().update(item, attributes, commit=commit) + return super().update(item, attributes) @classmethod def update_columns( cls, model: SqlaTable, property_columns: list[dict[str, Any]], - commit: bool = True, override_columns: bool = False, ) -> None: """ @@ 
-217,7 +213,7 @@ def update_columns( if not DatasetDAO.validate_python_date_format( column["python_date_format"] ): - raise DAOUpdateFailedError( + raise ValueError( "python_date_format is an invalid date/timestamp format." ) @@ -266,15 +262,11 @@ def update_columns( ) ).delete(synchronize_session="fetch") - if commit: - db.session.commit() - @classmethod def update_metrics( cls, model: SqlaTable, property_metrics: list[dict[str, Any]], - commit: bool = True, ) -> None: """ Creates/updates and/or deletes a list of metrics, based on a @@ -317,9 +309,6 @@ def update_metrics( ) ).delete(synchronize_session="fetch") - if commit: - db.session.commit() - @classmethod def find_dataset_column(cls, dataset_id: int, column_id: int) -> TableColumn | None: # We want to apply base dataset filters diff --git a/superset/daos/exceptions.py b/superset/daos/exceptions.py index 6fdd5a80d2c61..ebd20fee631aa 100644 --- a/superset/daos/exceptions.py +++ b/superset/daos/exceptions.py @@ -23,30 +23,6 @@ class DAOException(SupersetException): """ -class DAOCreateFailedError(DAOException): - """ - DAO Create failed - """ - - message = "Create failed" - - -class DAOUpdateFailedError(DAOException): - """ - DAO Update failed - """ - - message = "Update failed" - - -class DAODeleteFailedError(DAOException): - """ - DAO Delete failed - """ - - message = "Delete failed" - - class DatasourceTypeNotSupportedError(DAOException): """ DAO datasource query source type is not supported diff --git a/superset/daos/query.py b/superset/daos/query.py index ea7c82cc34dba..55287ebd9fff3 100644 --- a/superset/daos/query.py +++ b/superset/daos/query.py @@ -53,7 +53,6 @@ def update_saved_query_exec_info(query_id: int) -> None: for saved_query in related_saved_queries: saved_query.rows = query.rows saved_query.last_run = datetime.now() - db.session.commit() @staticmethod def save_metadata(query: Query, payload: dict[str, Any]) -> None: @@ -97,7 +96,6 @@ def stop_query(client_id: str) -> None: query.status = 
QueryStatus.STOPPED query.end_time = now_as_float() - db.session.commit() class SavedQueryDAO(BaseDAO[SavedQuery]): diff --git a/superset/daos/report.py b/superset/daos/report.py index 4662f325878d6..8cf305c13f26e 100644 --- a/superset/daos/report.py +++ b/superset/daos/report.py @@ -20,10 +20,7 @@ from datetime import datetime from typing import Any -from sqlalchemy.exc import SQLAlchemyError - from superset.daos.base import BaseDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.extensions import db from superset.reports.filters import ReportScheduleFilter from superset.reports.models import ( @@ -137,15 +134,12 @@ def create( cls, item: ReportSchedule | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> ReportSchedule: """ Create a report schedule with nested recipients. :param item: The object to create :param attributes: The attributes associated with the object to create - :param commit: Whether to commit the transaction - :raises: DAOCreateFailedError: If the creation failed """ # TODO(john-bodley): Determine why we need special handling for recipients. @@ -165,22 +159,19 @@ def create( for recipient in recipients ] - return super().create(item, attributes, commit) + return super().create(item, attributes) @classmethod def update( cls, item: ReportSchedule | None = None, attributes: dict[str, Any] | None = None, - commit: bool = True, ) -> ReportSchedule: """ Update a report schedule with nested recipients. :param item: The object to update :param attributes: The attributes associated with the object to update - :param commit: Whether to commit the transaction - :raises: DAOUpdateFailedError: If the update failed """ # TODO(john-bodley): Determine why we need special handling for recipients. 
@@ -200,7 +191,7 @@ def update( for recipient in recipients ] - return super().update(item, attributes, commit) + return super().update(item, attributes) @staticmethod def find_active() -> list[ReportSchedule]: @@ -283,23 +274,12 @@ def find_last_error_notification( return last_error_email_log if not report_from_last_email else None @staticmethod - def bulk_delete_logs( - model: ReportSchedule, - from_date: datetime, - commit: bool = True, - ) -> int | None: - try: - row_count = ( - db.session.query(ReportExecutionLog) - .filter( - ReportExecutionLog.report_schedule == model, - ReportExecutionLog.end_dttm < from_date, - ) - .delete(synchronize_session="fetch") + def bulk_delete_logs(model: ReportSchedule, from_date: datetime) -> int | None: + return ( + db.session.query(ReportExecutionLog) + .filter( + ReportExecutionLog.report_schedule == model, + ReportExecutionLog.end_dttm < from_date, ) - if commit: - db.session.commit() - return row_count - except SQLAlchemyError as ex: - db.session.rollback() - raise DAODeleteFailedError(str(ex)) from ex + .delete(synchronize_session="fetch") + ) diff --git a/superset/daos/tag.py b/superset/daos/tag.py index 46a1d2538f16a..b155cf15c1522 100644 --- a/superset/daos/tag.py +++ b/superset/daos/tag.py @@ -19,12 +19,11 @@ from typing import Any, Optional from flask import g -from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.exc import NoResultFound from superset.commands.tag.exceptions import TagNotFoundError from superset.commands.tag.utils import to_object_type from superset.daos.base import BaseDAO -from superset.daos.exceptions import DAODeleteFailedError from superset.exceptions import MissingUserContextException from superset.extensions import db from superset.models.dashboard import Dashboard @@ -75,7 +74,6 @@ def create_custom_tagged_objects( ) db.session.add_all(tagged_objects) - db.session.commit() @staticmethod def delete_tagged_object( @@ -86,9 +84,7 @@ def delete_tagged_object( """ tag = 
TagDAO.find_by_name(tag_name.strip()) if not tag: - raise DAODeleteFailedError( - message=f"Tag with name {tag_name} does not exist." - ) + raise NoResultFound(f"Tag with name {tag_name} does not exist.") tagged_object = db.session.query(TaggedObject).filter( TaggedObject.tag_id == tag.id, @@ -96,17 +92,13 @@ TaggedObject.object_id == object_id, ) if not tagged_object: - raise DAODeleteFailedError( - message=f'Tagged object with object_id: {object_id} \ + raise NoResultFound( + f'Tagged object with object_id: {object_id} \ object_type: {object_type} \ and tag name: "{tag_name}" could not be found' ) - try: - db.session.delete(tagged_object.one()) - db.session.commit() - except SQLAlchemyError as ex: # pragma: no cover - db.session.rollback() - raise DAODeleteFailedError(exception=ex) from ex + + db.session.delete(tagged_object.one()) @@ -117,18 +109,12 @@ def delete_tags(tag_names: list[str]) -> None: for name in tag_names: tag_name = name.strip() if not TagDAO.find_by_name(tag_name): - raise DAODeleteFailedError( - message=f"Tag with name {tag_name} does not exist." - ) + raise NoResultFound(f"Tag with name {tag_name} does not exist.") tags_to_delete.append(tag_name) tag_objects = db.session.query(Tag).filter(Tag.name.in_(tags_to_delete)) + for tag in tag_objects: - try: - db.session.delete(tag) - db.session.commit() - except SQLAlchemyError as ex: # pragma: no cover - db.session.rollback() - raise DAODeleteFailedError(exception=ex) from ex + db.session.delete(tag) @@ -283,21 +269,10 @@ favorite_tag_by_id_for_current_user( # pylint: disable=invalid-name ) -> None: """ Marks a specific tag as a favorite for the current user. - This function will find the tag by the provided id, - create a new UserFavoriteTag object that represents - the user's preference, add that object to the database - session, and commit the session.
It uses the currently - authenticated user from the global 'g' object. - Args: - tag_id: The id of the tag that is to be marked as - favorite. - Raises: - Any exceptions raised by the find_by_id function, - the UserFavoriteTag constructor, or the database session's - add and commit methods will propagate up to the caller. - Returns: - None. + + :param tag_id: The id of the tag that is to be marked as favorite """ + tag = TagDAO.find_by_id(tag_id) user = g.user @@ -307,26 +282,13 @@ def favorite_tag_by_id_for_current_user( # pylint: disable=invalid-name raise TagNotFoundError() tag.users_favorited.append(user) - db.session.commit() @staticmethod def remove_user_favorite_tag(tag_id: int) -> None: """ Removes a tag from the current user's favorite tags. - This function will find the tag by the provided id and remove the tag - from the user's list of favorite tags. It uses the currently authenticated - user from the global 'g' object. - - Args: - tag_id: The id of the tag that is to be removed from the favorite tags. - - Raises: - Any exceptions raised by the find_by_id function, the database session's - commit method will propagate up to the caller. - - Returns: - None. 
+ :param tag_id: The id of the tag that is to be removed from the favorite tags """ tag = TagDAO.find_by_id(tag_id) user = g.user @@ -338,9 +300,6 @@ def remove_user_favorite_tag(tag_id: int) -> None: tag.users_favorited.remove(user) - # Commit to save the changes - db.session.commit() - @staticmethod def favorited_ids(tags: list[Tag]) -> list[int]: """ @@ -424,5 +383,4 @@ def create_tag_relationship( object_id, tag.name, ) - db.session.add_all(tagged_objects) diff --git a/superset/daos/user.py b/superset/daos/user.py index cc6696cbdcc74..90a9b2bd2f6e5 100644 --- a/superset/daos/user.py +++ b/superset/daos/user.py @@ -40,4 +40,3 @@ def set_avatar_url(user: User, url: str) -> None: attrs = UserAttribute(avatar_url=url, user_id=user.id) user.extra_attributes = [attrs] db.session.add(attrs) - db.session.commit() diff --git a/superset/dashboards/api.py b/superset/dashboards/api.py index 3fe557a6843bc..823bfdfa8cc8b 100644 --- a/superset/dashboards/api.py +++ b/superset/dashboards/api.py @@ -32,7 +32,7 @@ from werkzeug.wrappers import Response as WerkzeugResponse from werkzeug.wsgi import FileWrapper -from superset import is_feature_enabled, thumbnail_cache +from superset import db, is_feature_enabled, thumbnail_cache from superset.charts.schemas import ChartEntityResponseSchema from superset.commands.dashboard.create import CreateDashboardCommand from superset.commands.dashboard.delete import DeleteDashboardCommand @@ -1314,7 +1314,13 @@ def set_embedded(self, dashboard: Dashboard) -> Response: """ try: body = self.embedded_config_schema.load(request.json) - embedded = EmbeddedDashboardDAO.upsert(dashboard, body["allowed_domains"]) + + with db.session.begin_nested(): + embedded = EmbeddedDashboardDAO.upsert( + dashboard, + body["allowed_domains"], + ) + result = self.embedded_response_schema.dump(embedded) return self.response(200, result=result) except ValidationError as error: diff --git a/superset/databases/api.py b/superset/databases/api.py index 
2c0aff8da03da..3a672eb7662b7 100644 --- a/superset/databases/api.py +++ b/superset/databases/api.py @@ -1410,7 +1410,7 @@ def oauth2(self) -> FlaskResponse: database_id=state["database_id"], ) if existing: - DatabaseUserOAuth2TokensDAO.delete([existing], commit=True) + DatabaseUserOAuth2TokensDAO.delete([existing]) # store tokens expiration = datetime.now() + timedelta(seconds=token_response["expires_in"]) @@ -1422,7 +1422,6 @@ def oauth2(self) -> FlaskResponse: "access_token_expiration": expiration, "refresh_token": token_response.get("refresh_token"), }, - commit=True, ) # return blank page that closes itself diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index e876aca8defdf..fd5ec6722ba07 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -455,4 +455,4 @@ def df_to_sql( # pylint: disable=too-many-locals catalog[table.table] = spreadsheet_url database.extra = json.dumps(extra) db.session.add(database) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 519618aaa6683..e3cf128b7a2c6 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -408,7 +408,7 @@ def handle_cursor( # pylint: disable=too-many-locals logger.info("Query %s: [%s] %s", str(query_id), str(job_id), l) last_log_line = len(log_lines) if needs_commit: - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction if sleep_interval := current_app.config.get("HIVE_POLL_INTERVAL"): logger.warning( "HIVE_POLL_INTERVAL is deprecated and will be removed in 3.0. 
Please use DB_POLL_INTERVAL_SECONDS instead" diff --git a/superset/db_engine_specs/impala.py b/superset/db_engine_specs/impala.py index 62360e77bbd10..ea74df83164f0 100644 --- a/superset/db_engine_specs/impala.py +++ b/superset/db_engine_specs/impala.py @@ -151,7 +151,7 @@ def handle_cursor(cls, cursor: Any, query: Query) -> None: needs_commit = True if needs_commit: - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction sleep_interval = current_app.config["DB_POLL_INTERVAL_SECONDS"].get( cls.engine, 5 ) diff --git a/superset/db_engine_specs/presto.py b/superset/db_engine_specs/presto.py index 5e0b433e1e11f..fbd0eff484474 100644 --- a/superset/db_engine_specs/presto.py +++ b/superset/db_engine_specs/presto.py @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -# pylint: disable=too-many-lines +# pylint: disable=consider-using-transaction,too-many-lines from __future__ import annotations import contextlib diff --git a/superset/db_engine_specs/trino.py b/superset/db_engine_specs/trino.py index ce0e03be77501..143276bdc3dca 100644 --- a/superset/db_engine_specs/trino.py +++ b/superset/db_engine_specs/trino.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
+# pylint: disable=consider-using-transaction from __future__ import annotations import contextlib diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index 8b3b315226528..c1a0897eb3f79 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -65,5 +65,4 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: tbl.description = "BART lines" tbl.database = database tbl.filter_select_enabled = True - db.session.commit() tbl.fetch_metadata() diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 229734057ceb9..81e31e7416558 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -111,8 +111,6 @@ def load_birth_names( _set_table_metadata(obj, database) _add_table_metrics(obj) - db.session.commit() - slices, _ = create_slices(obj) create_dashboard(slices) @@ -844,5 +842,4 @@ def create_dashboard(slices: list[Slice]) -> Dashboard: dash.dashboard_title = "USA Births Names" dash.position_json = json.dumps(pos, indent=4) dash.slug = "births" - db.session.commit() return dash diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py index 1741219470ac3..53f4a0b874ff3 100644 --- a/superset/examples/country_map.py +++ b/superset/examples/country_map.py @@ -88,7 +88,6 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N if not any(col.metric_name == "avg__2004" for col in obj.metrics): col = str(column("2004").compile(db.engine)) obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})")) - db.session.commit() obj.fetch_metadata() tbl = obj diff --git a/superset/examples/css_templates.py b/superset/examples/css_templates.py index 2f67d2e1faac9..91bb54c157752 100644 --- a/superset/examples/css_templates.py +++ b/superset/examples/css_templates.py @@ -52,7 +52,6 @@ def load_css_templates() -> None: """ ) obj.css = css - db.session.commit() obj = 
db.session.query(CssTemplate).filter_by(template_name="Courier Black").first() if not obj: @@ -97,4 +96,3 @@ def load_css_templates() -> None: """ ) obj.css = css - db.session.commit() diff --git a/superset/examples/deck.py b/superset/examples/deck.py index b0cb65b03fc2a..931924dd0879b 100644 --- a/superset/examples/deck.py +++ b/superset/examples/deck.py @@ -541,4 +541,3 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements dash.dashboard_title = title dash.slug = slug dash.slices = slices - db.session.commit() diff --git a/superset/examples/energy.py b/superset/examples/energy.py index 98b444f9db2f6..d7e46ec5d8c31 100644 --- a/superset/examples/energy.py +++ b/superset/examples/energy.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Loads datasets, dashboards and slices in a new superset instance""" - import textwrap import pandas as pd @@ -79,7 +77,6 @@ def load_energy( SqlMetric(metric_name="sum__value", expression=f"SUM({col})") ) - db.session.commit() tbl.fetch_metadata() slc = Slice( diff --git a/superset/examples/flights.py b/superset/examples/flights.py index 4db029519fd8b..f8659c24d07fc 100644 --- a/superset/examples/flights.py +++ b/superset/examples/flights.py @@ -66,6 +66,5 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: tbl.description = "Random set of flights in the US" tbl.database = database tbl.filter_select_enabled = True - db.session.commit() tbl.fetch_metadata() print("Done loading table!") diff --git a/superset/examples/helpers.py b/superset/examples/helpers.py index b865e2dfca935..4cc9a47b2700f 100644 --- a/superset/examples/helpers.py +++ b/superset/examples/helpers.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Loads datasets, dashboards and slices in a new superset instance""" - import os from typing import Any @@ -62,7 +60,6 @@ def merge_slice(slc: Slice) -> None: if o: db.session.delete(o) db.session.add(slc) - db.session.commit() def get_slice_json(defaults: dict[Any, Any], **kwargs: Any) -> str: diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py index 4f8de31453c18..5afb65f6fd91f 100644 --- a/superset/examples/long_lat.py +++ b/superset/examples/long_lat.py @@ -97,7 +97,6 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None obj.main_dttm_col = "datetime" obj.database = database obj.filter_select_enabled = True - db.session.commit() obj.fetch_metadata() tbl = obj diff --git a/superset/examples/misc_dashboard.py b/superset/examples/misc_dashboard.py index 825dc6352c8e3..4a7079e2cddc3 100644 --- a/superset/examples/misc_dashboard.py +++ b/superset/examples/misc_dashboard.py @@ -140,4 +140,3 @@ def load_misc_dashboard() -> None: dash.position_json = json.dumps(pos, indent=4) dash.slug = DASH_SLUG dash.slices = slices - db.session.commit() diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index 979be10686f5a..9cfe44c1994c1 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -102,7 +102,6 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals col.python_date_format = dttm_and_expr[0] col.database_expression = dttm_and_expr[1] col.is_dttm = True - db.session.commit() obj.fetch_metadata() tbl = obj diff --git a/superset/examples/paris.py b/superset/examples/paris.py index 990aa01ca6c30..928e2294072a4 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -62,5 +62,4 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> tbl.description = "Map of Paris" tbl.database = database tbl.filter_select_enabled = True - db.session.commit() 
tbl.fetch_metadata() diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index ec232995fa2e7..10ece826b6a1d 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import pandas as pd from sqlalchemy import DateTime, inspect, String @@ -72,7 +71,6 @@ def load_random_time_series_data( obj.main_dttm_col = "ds" obj.database = database obj.filter_select_enabled = True - db.session.commit() obj.fetch_metadata() tbl = obj diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index 4fa59db721a69..b8d5527ed247b 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -64,5 +64,4 @@ def load_sf_population_polygons( tbl.description = "Population density of San Francisco" tbl.database = database tbl.filter_select_enabled = True - db.session.commit() tbl.fetch_metadata() diff --git a/superset/examples/supported_charts_dashboard.py b/superset/examples/supported_charts_dashboard.py index 49141eb73cf62..c605bf88cc571 100644 --- a/superset/examples/supported_charts_dashboard.py +++ b/superset/examples/supported_charts_dashboard.py @@ -14,9 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
- # pylint: disable=too-many-lines - import textwrap from sqlalchemy import inspect @@ -1274,4 +1272,3 @@ def load_supported_charts_dashboard() -> None: dash.dashboard_title = "Supported Charts Dashboard" dash.position_json = json.dumps(pos, indent=2) dash.slug = DASH_SLUG - db.session.commit() diff --git a/superset/examples/tabbed_dashboard.py b/superset/examples/tabbed_dashboard.py index bbc11e77306ae..b44c2a6d2be91 100644 --- a/superset/examples/tabbed_dashboard.py +++ b/superset/examples/tabbed_dashboard.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -"""Loads datasets, dashboards and slices in a new superset instance""" - import textwrap from superset import db @@ -558,4 +556,3 @@ def load_tabbed_dashboard(_: bool = False) -> None: dash.slices = slices dash.dashboard_title = "Tabbed Dashboard" dash.slug = slug - db.session.commit() diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index afbb6a994a831..a9c06dfa2942a 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -14,8 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. 
-"""Loads datasets, dashboards and slices in a new superset instance""" - import os import pandas as pd @@ -41,7 +39,7 @@ from superset.utils.core import DatasourceType -def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-statements +def load_world_bank_health_n_pop( # pylint: disable=too-many-locals only_metadata: bool = False, force: bool = False, sample: bool = False, @@ -110,7 +108,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s SqlMetric(metric_name=metric, expression=f"{aggr_func}({col})") ) - db.session.commit() tbl.fetch_metadata() slices = create_slices(tbl) @@ -134,7 +131,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s dash.position_json = json.dumps(pos, indent=4) dash.slug = slug dash.slices = slices - db.session.commit() def create_slices(tbl: BaseDatasource) -> list[Slice]: diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index 7b4e39677e48f..1c89e8459774d 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -22,7 +22,6 @@ from flask import current_app, Flask, has_app_context from flask_caching import BaseCache -from superset import db from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.types import ( KeyValueCodec, @@ -95,7 +94,6 @@ def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: codec=self.codec, expires_on=self._get_expiry(timeout), ).run() - db.session.commit() return True def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: @@ -111,7 +109,6 @@ def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: key=self.get_key(key), expires_on=self._get_expiry(timeout), ).run() - db.session.commit() return True except KeyValueCreateFailedError: return False @@ -136,6 +133,4 @@ def delete(self, key: str) -> Any: # pylint: disable=import-outside-toplevel from 
superset.commands.key_value.delete import DeleteKeyValueCommand - ret = DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() - db.session.commit() - return ret + return DeleteKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() diff --git a/superset/extensions/pylint.py b/superset/extensions/pylint.py index 1cf9821f44606..5925f180b785f 100644 --- a/superset/extensions/pylint.py +++ b/superset/extensions/pylint.py @@ -56,5 +56,22 @@ def visit_importfrom(self, node: nodes.ImportFrom) -> None: self.add_message("disallowed-import", node=node) +class TransactionChecker(BaseChecker): + name = "consider-using-transaction" + msgs = { + "W0001": ( + 'Consider using the @transaction decorator when defining a "unit of work"', + "consider-using-transaction", + "Used when an explicit commit or rollback call is detected", + ), + } + + def visit_call(self, node: nodes.Call) -> None: + if isinstance(node.func, nodes.Attribute): + if node.func.attrname in ("commit", "rollback"): + self.add_message("consider-using-transaction", node=node) + + def register(linter: PyLinter) -> None: linter.register_checker(JSONLibraryImportChecker(linter)) + linter.register_checker(TransactionChecker(linter)) diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index a98b8c94892dd..f074eaf293d53 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -56,6 +56,7 @@ from superset.superset_typing import FlaskResponse from superset.tags.core import register_sqla_event_listeners from superset.utils.core import is_test, pessimistic_connection_handling +from superset.utils.decorators import transaction from superset.utils.log import DBEventLogger, get_event_logger_from_cfg_value if TYPE_CHECKING: @@ -513,6 +514,7 @@ def configure_cache(self) -> None: def configure_feature_flags(self) -> None: feature_flag_manager.init_app(self.superset_app) + @transaction() def configure_fab(self) -> None: if 
self.config["SILENCE_FAB"]: logging.getLogger("flask_appbuilder").setLevel(logging.ERROR) diff --git a/superset/key_value/shared_entries.py b/superset/key_value/shared_entries.py index f472838d2e090..130313157a53d 100644 --- a/superset/key_value/shared_entries.py +++ b/superset/key_value/shared_entries.py @@ -18,7 +18,6 @@ from typing import Any, Optional from uuid import uuid3 -from superset import db from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey from superset.key_value.utils import get_uuid_namespace, random_key @@ -46,7 +45,6 @@ def set_shared_value(key: SharedKey, value: Any) -> None: key=uuid_key, codec=CODEC, ).run() - db.session.commit() def get_permalink_salt(key: SharedKey) -> str: diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index c2048f2a556c1..28d8aacc7bed9 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -83,7 +83,7 @@ def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard) user_id=target.id, welcome_dashboard_id=dashboard.id ) session.add(extra_attributes) - session.commit() + session.commit() # pylint: disable=consider-using-transaction sqla.event.listen(User, "after_insert", copy_dashboard) diff --git a/superset/queries/api.py b/superset/queries/api.py index 0695946fe07f3..67afd8a81763c 100644 --- a/superset/queries/api.py +++ b/superset/queries/api.py @@ -231,8 +231,8 @@ def get_updated_since(self, **kwargs: Any) -> FlaskResponse: backoff.constant, Exception, interval=1, - on_backoff=lambda details: db.session.rollback(), - on_giveup=lambda details: db.session.rollback(), + on_backoff=lambda details: db.session.rollback(), # pylint: disable=consider-using-transaction + on_giveup=lambda details: db.session.rollback(), # pylint: disable=consider-using-transaction max_tries=5, ) @requires_json diff --git a/superset/row_level_security/api.py b/superset/row_level_security/api.py index 86956683cb15e..077d55ff4ebb7 100644 --- 
a/superset/row_level_security/api.py +++ b/superset/row_level_security/api.py @@ -23,6 +23,7 @@ from flask_appbuilder.models.sqla.interface import SQLAInterface from flask_babel import ngettext from marshmallow import ValidationError +from sqlalchemy.exc import SQLAlchemyError from superset.commands.exceptions import ( DatasourceNotFoundValidationError, @@ -34,7 +35,6 @@ from superset.commands.security.update import UpdateRLSRuleCommand from superset.connectors.sqla.models import RowLevelSecurityFilter from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod -from superset.daos.exceptions import DAOCreateFailedError, DAOUpdateFailedError from superset.extensions import event_logger from superset.row_level_security.schemas import ( get_delete_ids_schema, @@ -205,7 +205,7 @@ def post(self) -> Response: exc_info=True, ) return self.response_422(message=str(ex)) - except DAOCreateFailedError as ex: + except SQLAlchemyError as ex: logger.error( "Error creating RLS rule %s: %s", self.__class__.__name__, @@ -291,7 +291,7 @@ def put(self, pk: int) -> Response: exc_info=True, ) return self.response_422(message=str(ex)) - except DAOUpdateFailedError as ex: + except SQLAlchemyError as ex: logger.error( "Error updating RLS rule %s: %s", self.__class__.__name__, diff --git a/superset/security/manager.py b/superset/security/manager.py index ea2ee5ef83a12..b4bc0c6103def 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -1019,7 +1019,6 @@ def clean_perms(self) -> None: ) if deleted_count := pvms.delete(): logger.info("Deleted %i faulty permissions", deleted_count) - self.get_session.commit() def sync_role_definitions(self) -> None: """ @@ -1045,7 +1044,6 @@ def sync_role_definitions(self) -> None: self.auth_role_public, merge=True, ) - self.create_missing_perms() self.clean_perms() @@ -1119,7 +1117,6 @@ def copy_role( ): role_from_permissions.append(permission_view) role_to.permissions = role_from_permissions - 
self.get_session.commit() def set_role( self, @@ -1140,7 +1137,6 @@ def set_role( permission_view for permission_view in pvms if pvm_check(permission_view) ] role.permissions = role_pvms - self.get_session.commit() def _is_admin_only(self, pvm: PermissionView) -> bool: """ diff --git a/superset/sql_lab.py b/superset/sql_lab.py index cb2cbe455cce1..9712ab47ab426 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=consider-using-transaction import dataclasses import logging import uuid @@ -127,6 +128,7 @@ def handle_query_error( def get_query_backoff_handler(details: dict[Any, Any]) -> None: + print(details) query_id = details["kwargs"]["query_id"] logger.error( "Query with id `%s` could not be retrieved", str(query_id), exc_info=True diff --git a/superset/sqllab/sql_json_executer.py b/superset/sqllab/sql_json_executer.py index fde73aef0a86e..ac9968ed6b467 100644 --- a/superset/sqllab/sql_json_executer.py +++ b/superset/sqllab/sql_json_executer.py @@ -90,6 +90,7 @@ def execute( rendered_query: str, log_params: dict[str, Any] | None, ) -> SqlJsonExecutionStatus: + print(">>> execute <<<") query_id = execution_context.query.id try: data = self._get_sql_results_with_timeout( @@ -101,6 +102,7 @@ def execute( raise except Exception as ex: logger.exception("Query %i failed unexpectedly", query_id) + print(str(ex)) raise SupersetGenericDBErrorException( utils.error_msg_from_exception(ex) ) from ex @@ -112,6 +114,7 @@ def execute( [SupersetError(**params) for params in data["errors"]] # type: ignore ) # old string-only error message + print(data) raise SupersetGenericDBErrorException(data["error"]) # type: ignore return SqlJsonExecutionStatus.HAS_RESULTS diff --git a/superset/tags/models.py b/superset/tags/models.py index ba859f519bf47..31975c3e8e882 100644 --- a/superset/tags/models.py +++ 
b/superset/tags/models.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=consider-using-transaction from __future__ import annotations import enum diff --git a/superset/tasks/celery_app.py b/superset/tasks/celery_app.py index 4d36917be0bb8..5a0963ccd544b 100644 --- a/superset/tasks/celery_app.py +++ b/superset/tasks/celery_app.py @@ -62,7 +62,7 @@ def teardown( # pylint: disable=unused-argument if flask_app.config.get("SQLALCHEMY_COMMIT_ON_TEARDOWN"): if not isinstance(retval, Exception): - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction if not flask_app.config.get("CELERY_ALWAYS_EAGER"): db.session.remove() diff --git a/superset/utils/database.py b/superset/utils/database.py index 073e58ffda6fb..719e7f2d772c7 100644 --- a/superset/utils/database.py +++ b/superset/utils/database.py @@ -54,13 +54,12 @@ def get_or_create_db( ) db.session.add(database) database.set_sqlalchemy_uri(sqlalchemy_uri) - db.session.commit() # todo: it's a bad idea to do an update in a get/create function if database and database.sqlalchemy_uri_decrypted != sqlalchemy_uri: database.set_sqlalchemy_uri(sqlalchemy_uri) - db.session.commit() + db.session.flush() return database @@ -80,4 +79,4 @@ def remove_database(database: Database) -> None: from superset import db db.session.delete(database) - db.session.commit() + db.session.flush() diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 3900bdd4156a1..844a8f063c1b8 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -20,10 +20,12 @@ import time from collections.abc import Iterator from contextlib import contextmanager +from functools import wraps from typing import Any, Callable, TYPE_CHECKING from uuid import UUID from flask import current_app, g, Response +from sqlalchemy.exc import SQLAlchemyError from superset.utils import 
core as utils from superset.utils.dates import now_as_float @@ -207,3 +209,64 @@ def suppress_logging( yield finally: target_logger.setLevel(original_level) + + +def on_error( + ex: Exception, + catches: tuple[type[Exception], ...] = (SQLAlchemyError,), + reraise: type[Exception] | None = SQLAlchemyError, +) -> None: + """ + Default error handler whenever any exception is caught during a SQLAlchemy nested + transaction. + + :param ex: The source exception + :param catches: The exception types the handler catches + :param reraise: The exception type the handler raises after catching + :raises Exception: If the exception is not swallowed + """ + + if isinstance(ex, catches): + if hasattr(ex, "exception"): + logger.exception(ex.exception) + + if reraise: + raise reraise() from ex + else: + raise ex + + +def transaction( # pylint: disable=redefined-outer-name + on_error: Callable[..., Any] | None = on_error, +) -> Callable[..., Any]: + """ + Perform a "unit of work". + + Note ideally this would leverage SQLAlchemy's nested transaction, however this + proved rather complicated, likely due to many architectural facets, and thus has + been left for a follow up exercise. 
+ + :param on_error: Callback invoked when an exception is caught + :see: https://github.com/apache/superset/issues/25108 + """ + + def decorate(func: Callable[..., Any]) -> Callable[..., Any]: + @wraps(func) + def wrapped(*args: Any, **kwargs: Any) -> Any: + from superset import db # pylint: disable=import-outside-toplevel + + try: + result = func(*args, **kwargs) + db.session.commit() # pylint: disable=consider-using-transaction + return result + except Exception as ex: + db.session.rollback() # pylint: disable=consider-using-transaction + + if on_error: + return on_error(ex) + + raise + + return wrapped + + return decorate diff --git a/superset/utils/lock.py b/superset/utils/lock.py index 3cd3c8ead53ab..4723b57fa1b01 100644 --- a/superset/utils/lock.py +++ b/superset/utils/lock.py @@ -24,7 +24,6 @@ from datetime import datetime, timedelta from typing import Any, cast, TypeVar, Union -from superset import db from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.types import JsonKeyValueCodec, KeyValueResource @@ -72,7 +71,6 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name store. :param namespace: The namespace for which the lock is to be acquired. - :type namespace: str :param kwargs: Additional keyword arguments. :yields: A unique identifier (UUID) for the acquired lock (the KV key). :raises CreateKeyValueDistributedLockFailedException: If the lock is taken. 
@@ -93,12 +91,10 @@ def KeyValueDistributedLock( # pylint: disable=invalid-name value=True, expires_on=datetime.now() + LOCK_EXPIRATION, ).run() - db.session.commit() yield key DeleteKeyValueCommand(resource=KeyValueResource.LOCK, key=key).run() - db.session.commit() logger.debug("Removed lock on namespace %s for key %s", namespace, key) except KeyValueCreateFailedError as ex: raise CreateKeyValueDistributedLockFailedException( diff --git a/superset/utils/log.py b/superset/utils/log.py index 4b9ebb50b989b..71c552883307d 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -403,7 +403,7 @@ def log( # pylint: disable=too-many-arguments,too-many-locals logs.append(log) try: db.session.bulk_save_objects(logs) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction except SQLAlchemyError as ex: logging.error("DBEventLogger failed to log event(s)") logging.exception(ex) diff --git a/superset/views/base.py b/superset/views/base.py index dc90b07728975..b836c5076c4e3 100644 --- a/superset/views/base.py +++ b/superset/views/base.py @@ -63,6 +63,7 @@ app as superset_app, appbuilder, conf, + db, get_feature_flags, is_feature_enabled, security_manager, @@ -698,7 +699,7 @@ def _delete(self: BaseView, primary_key: int) -> None: if view_menu: security_manager.get_session.delete(view_menu) - security_manager.get_session.commit() + db.session.commit() # pylint: disable=consider-using-transaction flash(*self.datamodel.message) self.update_redirect() diff --git a/superset/views/core.py b/superset/views/core.py index 5f76b05a78817..75a04dedc7a81 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -619,10 +619,12 @@ def save_or_overwrite_slice( if action == "saveas" and slice_add_perm: ChartDAO.create(slc) + db.session.commit() # pylint: disable=consider-using-transaction msg = _("Chart [{}] has been saved").format(slc.slice_name) flash(msg, "success") elif action == "overwrite" and slice_overwrite_perm: ChartDAO.update(slc) 
+ db.session.commit() # pylint: disable=consider-using-transaction msg = _("Chart [{}] has been overwritten").format(slc.slice_name) flash(msg, "success") @@ -676,7 +678,7 @@ def save_or_overwrite_slice( if dash and slc not in dash.slices: dash.slices.append(slc) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction response = { "can_add": slice_add_perm, diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py index 2e88b4acd02b7..8a419fcb26f97 100644 --- a/superset/views/dashboard/views.py +++ b/superset/views/dashboard/views.py @@ -122,7 +122,7 @@ def new(self) -> FlaskResponse: owners=[g.user], ) db.session.add(new_dashboard) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction return redirect(f"/superset/dashboard/{new_dashboard.id}/?edit=true") @expose("//embedded") diff --git a/superset/views/datasource/views.py b/superset/views/datasource/views.py index 89907df000fa1..377579cf05d16 100644 --- a/superset/views/datasource/views.py +++ b/superset/views/datasource/views.py @@ -116,7 +116,7 @@ def save(self) -> FlaskResponse: ) orm_datasource.update_from_object(datasource_dict) data = orm_datasource.data - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction return self.json_response(sanitize_datasource_data(data)) diff --git a/superset/views/key_value.py b/superset/views/key_value.py index 3ba53073c7047..69a5314c5fb57 100644 --- a/superset/views/key_value.py +++ b/superset/views/key_value.py @@ -48,7 +48,7 @@ def store(self) -> FlaskResponse: value = request.form.get("data") obj = models.KeyValue(value=value) db.session.add(obj) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction except Exception as ex: # pylint: disable=broad-except return json_error_response(utils.error_msg_from_exception(ex)) return Response(json.dumps({"id": obj.id}), status=200) diff --git a/superset/views/sql_lab/views.py 
b/superset/views/sql_lab/views.py index 693299118d08a..3ec3667267471 100644 --- a/superset/views/sql_lab/views.py +++ b/superset/views/sql_lab/views.py @@ -14,6 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +# pylint: disable=consider-using-transaction import logging from flask import request, Response @@ -272,6 +273,5 @@ def expanded(self, table_schema_id: int) -> FlaskResponse: .filter_by(id=table_schema_id) .update({"expanded": payload}) ) - db.session.commit() response = json.dumps({"id": table_schema_id, "expanded": payload}) return json_success(response) diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 77633d65642e9..0e407b86573d4 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -203,8 +203,7 @@ def temporary_user( previous_g_user = g.user if hasattr(g, "user") else None try: if login: - resp = self.login(username=temp_user.username) - print(resp) + self.login(username=temp_user.username) else: g.user = temp_user yield temp_user diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index 6d25fe81905a1..a9af7c12b3994 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -1266,7 +1266,6 @@ def test_admin_gets_filtered_energy_slices(self): assert rv.status_code == 200 assert data["count"] > 0 for chart in data["result"]: - print(chart) assert ( "energy" in " ".join( diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index 58cfd9d494cb9..56b0a9a793b0c 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -1211,6 +1211,9 @@ def test_chart_data_cache_no_login(self, cache_loader): """ Chart data cache API: Test chart data async cache request 
(no login) """ + if get_example_database().backend == "presto": + return + app._got_first_request = False async_query_manager_factory.init_app(app) self.logout() diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py index f180da9aed8b8..537c1c882e0ce 100644 --- a/tests/integration_tests/conftest.py +++ b/tests/integration_tests/conftest.py @@ -124,10 +124,6 @@ def setup_sample_data() -> Any: with app.app_context(): setup_presto_if_needed() - from superset.cli.test import load_test_users_run - - load_test_users_run() - from superset.examples.css_templates import load_css_templates load_css_templates() diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 9166d549588c0..44b7ef26e64cd 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -814,7 +814,7 @@ def set(self): mock_cache.return_value = MockCache() rv = self.client.get("/superset/explore_json/data/valid-cache-key") - self.assertEqual(rv.status_code, 401) + self.assertEqual(rv.status_code, 403) def test_explore_json_data_invalid_cache_key(self): self.login(ADMIN_USERNAME) diff --git a/tests/integration_tests/dashboard_tests.py b/tests/integration_tests/dashboard_tests.py index 1852adba48af3..bee8de7a5e064 100644 --- a/tests/integration_tests/dashboard_tests.py +++ b/tests/integration_tests/dashboard_tests.py @@ -186,7 +186,11 @@ def test_dashboard_with_created_by_can_be_accessed_by_public_users(self): # Cleanup self.revoke_public_access_to_table(table) - @pytest.mark.usefixtures("load_energy_table_with_slice", "load_dashboard") + @pytest.mark.usefixtures( + "public_role_like_gamma", + "load_energy_table_with_slice", + "load_dashboard", + ) def test_users_can_list_published_dashboard(self): self.login(ALPHA_USERNAME) resp = self.get_resp("/api/v1/dashboard/") diff --git a/tests/integration_tests/dashboards/commands_tests.py b/tests/integration_tests/dashboards/commands_tests.py index 
06edd6c6d0f18..334e0425cf1f3 100644 --- a/tests/integration_tests/dashboards/commands_tests.py +++ b/tests/integration_tests/dashboards/commands_tests.py @@ -592,7 +592,6 @@ def test_import_v1_dashboard_multiple(self, mock_g): } command = v1.ImportDashboardsCommand(contents, overwrite=True) command.run() - command.run() new_num_dashboards = db.session.query(Dashboard).count() assert new_num_dashboards == num_dashboards + 1 diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index d4a1ac08c21c4..8d0cd0810f8b1 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -281,7 +281,6 @@ def test_create_database(self): "server_cert": None, "extra": json.dumps(extra), } - uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) @@ -713,7 +712,6 @@ def test_cascade_delete_ssh_tunnel( "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, "ssh_tunnel": ssh_tunnel_properties, } - uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) @@ -923,7 +921,6 @@ def test_create_database_invalid_configuration_method(self): "server_cert": None, "extra": json.dumps(extra), } - uri = "api/v1/database/" rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 59277a5bb6dd6..37de6e87c27ad 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -26,17 +26,13 @@ import pytest import yaml from sqlalchemy import inspect +from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import joinedload from sqlalchemy.sql import func from superset import app # noqa: F401 from superset.commands.dataset.exceptions import DatasetCreateFailedError from 
superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn -from superset.daos.exceptions import ( - DAOCreateFailedError, - DAODeleteFailedError, - DAOUpdateFailedError, -) from superset.extensions import db, security_manager from superset.models.core import Database from superset.models.slice import Slice @@ -197,7 +193,6 @@ def test_user_gets_all_datasets(self): def count_datasets(): uri = "api/v1/chart/" rv = self.client.get(uri, "get_list") - print(rv.data) self.assertEqual(rv.status_code, 200) data = rv.get_json() return data["count"] @@ -879,7 +874,7 @@ def test_create_dataset_sqlalchemy_error(self, mock_dao_create): Dataset API: Test create dataset sqlalchemy error """ - mock_dao_create.side_effect = DAOCreateFailedError() + mock_dao_create.side_effect = SQLAlchemyError() self.login(ADMIN_USERNAME) main_db = get_main_database() dataset_data = { @@ -1487,7 +1482,7 @@ def test_update_dataset_sqlalchemy_error(self, mock_dao_update): Dataset API: Test update dataset sqlalchemy error """ - mock_dao_update.side_effect = DAOUpdateFailedError() + mock_dao_update.side_effect = SQLAlchemyError() dataset = self.insert_default_dataset() self.login(ADMIN_USERNAME) @@ -1551,7 +1546,7 @@ def test_delete_dataset_sqlalchemy_error(self, mock_dao_delete): Dataset API: Test delete dataset sqlalchemy error """ - mock_dao_delete.side_effect = DAODeleteFailedError() + mock_dao_delete.side_effect = SQLAlchemyError() dataset = self.insert_default_dataset() self.login(ADMIN_USERNAME) @@ -1620,7 +1615,7 @@ def test_delete_dataset_column_fail(self, mock_dao_delete): Dataset API: Test delete dataset column """ - mock_dao_delete.side_effect = DAODeleteFailedError() + mock_dao_delete.side_effect = SQLAlchemyError() dataset = self.get_fixture_datasets()[0] column_id = dataset.columns[0].id self.login(ADMIN_USERNAME) @@ -1692,7 +1687,7 @@ def test_delete_dataset_metric_fail(self, mock_dao_delete): Dataset API: Test delete dataset metric """ - mock_dao_delete.side_effect = 
DAODeleteFailedError() + mock_dao_delete.side_effect = SQLAlchemyError() dataset = self.get_fixture_datasets()[0] column_id = dataset.metrics[0].id self.login(ADMIN_USERNAME) diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 718b6d2d9835d..aaad26b85d723 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -88,7 +88,6 @@ def test_external_metadata_for_physical_table(self): ) def test_always_filter_main_dttm(self): - self.login(ADMIN_USERNAME) database = get_example_database() sql = f"SELECT DATE() as default_dttm, DATE() as additional_dttm, 1 as metric;" # noqa: F541 @@ -363,7 +362,6 @@ def test_save(self): elif k == "owners": self.assertEqual([o["id"] for o in resp[k]], datasource_post["owners"]) else: - print(k) self.assertEqual(resp[k], datasource_post[k]) def test_save_default_endpoint_validation_success(self): diff --git a/tests/integration_tests/embedded/api_tests.py b/tests/integration_tests/embedded/api_tests.py index 533f1311d3d6c..64afaa178496a 100644 --- a/tests/integration_tests/embedded/api_tests.py +++ b/tests/integration_tests/embedded/api_tests.py @@ -44,6 +44,7 @@ def test_get_embedded_dashboard(self): self.login(ADMIN_USERNAME) self.dash = db.session.query(Dashboard).filter_by(slug="births").first() self.embedded = EmbeddedDashboardDAO.upsert(self.dash, []) + db.session.flush() uri = f"api/v1/{self.resource_name}/{self.embedded.uuid}" response = self.client.get(uri) self.assert200(response) diff --git a/tests/integration_tests/embedded/dao_tests.py b/tests/integration_tests/embedded/dao_tests.py index e1f72feb89db8..eed161581fe71 100644 --- a/tests/integration_tests/embedded/dao_tests.py +++ b/tests/integration_tests/embedded/dao_tests.py @@ -34,17 +34,21 @@ def test_upsert(self): dash = db.session.query(Dashboard).filter_by(slug="world_health").first() assert not dash.embedded EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) + 
db.session.flush() assert dash.embedded self.assertEqual(dash.embedded[0].allowed_domains, ["test.example.com"]) original_uuid = dash.embedded[0].uuid self.assertIsNotNone(original_uuid) EmbeddedDashboardDAO.upsert(dash, []) + db.session.flush() self.assertEqual(dash.embedded[0].allowed_domains, []) self.assertEqual(dash.embedded[0].uuid, original_uuid) @pytest.mark.usefixtures("load_world_bank_dashboard_with_slices") def test_get_by_uuid(self): dash = db.session.query(Dashboard).filter_by(slug="world_health").first() - uuid = str(EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]).uuid) + EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) + db.session.flush() + uuid = str(dash.embedded[0].uuid) embedded = EmbeddedDashboardDAO.find_by_id(uuid) self.assertIsNotNone(embedded) diff --git a/tests/integration_tests/embedded/test_view.py b/tests/integration_tests/embedded/test_view.py index 7fcfcdba9ff0e..f4d5ae6925568 100644 --- a/tests/integration_tests/embedded/test_view.py +++ b/tests/integration_tests/embedded/test_view.py @@ -44,6 +44,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811 dash = db.session.query(Dashboard).filter_by(slug="births").first() embedded = EmbeddedDashboardDAO.upsert(dash, []) + db.session.flush() uri = f"embedded/{embedded.uuid}" response = client.get(uri) assert response.status_code == 200 @@ -57,6 +58,7 @@ def test_get_embedded_dashboard(client: FlaskClient[Any]): # noqa: F811 def test_get_embedded_dashboard_referrer_not_allowed(client: FlaskClient[Any]): # noqa: F811 dash = db.session.query(Dashboard).filter_by(slug="births").first() embedded = EmbeddedDashboardDAO.upsert(dash, ["test.example.com"]) + db.session.flush() uri = f"embedded/{embedded.uuid}" response = client.get(uri) assert response.status_code == 403 diff --git a/tests/integration_tests/fixtures/unicode_dashboard.py b/tests/integration_tests/fixtures/unicode_dashboard.py index e68e8f079944f..970845783058c 100644 --- 
a/tests/integration_tests/fixtures/unicode_dashboard.py +++ b/tests/integration_tests/fixtures/unicode_dashboard.py @@ -114,7 +114,8 @@ def _create_and_commit_unicode_slice(table: SqlaTable, title: str): def _cleanup(dash: Dashboard, slice_name: str) -> None: db.session.delete(dash) - if slice_name: - slice = db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none() + if slice_name and ( + slice := db.session.query(Slice).filter_by(slice_name=slice_name).one_or_none() + ): db.session.delete(slice) db.session.commit() diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 2c8a13a71f4d6..71bb1484e0330 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -215,8 +215,6 @@ def test_model_view_rls_add_name_unique(self): }, ) self.assertEqual(rv.status_code, 422) - data = json.loads(rv.data.decode("utf-8")) - assert "Create failed" in data["message"] @pytest.mark.usefixtures("create_dataset") def test_model_view_rls_add_tables_required(self): diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index f5569b1c83912..86fffee1ec89a 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -543,8 +543,7 @@ def test_fetch_metadata_for_updated_virtual_table(self): # make sure the columns have been mapped properly assert len(table.columns) == 4 - with db.session.no_autoflush: - table.fetch_metadata(commit=False) + table.fetch_metadata() # assert that the removed column has been dropped and # the physical and calculated columns are present diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index a36cb8a8ec35a..829854d966810 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -73,7 +73,6 @@ class 
TestSqlLab(SupersetTestCase): def run_some_queries(self): db.session.query(Query).delete() - db.session.commit() self.run_sql(QUERY_1, client_id="client_id_1", username="admin") self.run_sql(QUERY_2, client_id="client_id_2", username="admin") self.run_sql(QUERY_3, client_id="client_id_3", username="gamma_sqllab") diff --git a/tests/integration_tests/superset_test_config.py b/tests/integration_tests/superset_test_config.py index 04472bfc24647..0935714c54275 100644 --- a/tests/integration_tests/superset_test_config.py +++ b/tests/integration_tests/superset_test_config.py @@ -95,6 +95,7 @@ def GET_FEATURE_FLAGS_FUNC(ff): FAB_ROLES = {"TestRole": [["Security", "menu_access"], ["List Users", "menu_access"]]} +PUBLIC_ROLE_LIKE = "Gamma" AUTH_ROLE_PUBLIC = "Public" EMAIL_NOTIFICATIONS = False REDIS_HOST = os.environ.get("REDIS_HOST", "localhost") # noqa: F405 diff --git a/tests/integration_tests/tags/dao_tests.py b/tests/integration_tests/tags/dao_tests.py index b06e22054ec63..8a6ba6e5f4b3a 100644 --- a/tests/integration_tests/tags/dao_tests.py +++ b/tests/integration_tests/tags/dao_tests.py @@ -18,7 +18,6 @@ from operator import and_ from unittest.mock import patch # noqa: F401 import pytest -from superset.daos.exceptions import DAOCreateFailedError, DAOException # noqa: F401 from superset.models.slice import Slice from superset.models.sql_lab import SavedQuery # noqa: F401 from superset.daos.tag import TagDAO @@ -188,6 +187,7 @@ def test_get_objects_from_tag(self): TaggedObject.object_type == ObjectType.chart, ), ) + .join(Tag, TaggedObject.tag_id == Tag.id) .distinct(Slice.id) .count() ) @@ -200,6 +200,7 @@ def test_get_objects_from_tag(self): TaggedObject.object_type == ObjectType.dashboard, ), ) + .join(Tag, TaggedObject.tag_id == Tag.id) .distinct(Dashboard.id) .count() + num_charts diff --git a/tests/unit_tests/commands/databases/create_test.py b/tests/unit_tests/commands/databases/create_test.py index 405238827d5cf..09d5744afd53b 100644 --- 
a/tests/unit_tests/commands/databases/create_test.py +++ b/tests/unit_tests/commands/databases/create_test.py @@ -29,7 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database with catalogs and schemas. """ - mocker.patch("superset.commands.database.create.db") mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") database = mocker.MagicMock() @@ -53,7 +52,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database without catalogs. """ - mocker.patch("superset.commands.database.create.db") mocker.patch("superset.commands.database.create.TestConnectionDatabaseCommand") database = mocker.MagicMock() diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index 300efb62e7d3c..37500d521420a 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -29,8 +29,6 @@ def database_with_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database with catalogs and schemas. """ - mocker.patch("superset.commands.database.update.db") - database = mocker.MagicMock() database.database_name = "my_db" database.db_engine_spec.__name__ = "test_engine" @@ -50,8 +48,6 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock: """ Mock a database without catalogs. 
""" - mocker.patch("superset.commands.database.update.db") - database = mocker.MagicMock() database.database_name = "my_db" database.db_engine_spec.__name__ = "test_engine" diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py index d50e7d8a28e08..7662393d4fc49 100644 --- a/tests/unit_tests/dao/tag_test.py +++ b/tests/unit_tests/dao/tag_test.py @@ -22,7 +22,6 @@ def test_user_favorite_tag(mocker): from superset.daos.tag import TagDAO # Mock the behavior of TagDAO and g - mock_session = mocker.patch("superset.daos.tag.db.session") mock_TagDAO = mocker.patch( "superset.daos.tag.TagDAO" ) # Replace with the actual path to TagDAO @@ -40,14 +39,11 @@ def test_user_favorite_tag(mocker): # Check that users_favorited was updated correctly assert mock_TagDAO.find_by_id().users_favorited == [mock_g.user] - mock_session.commit.assert_called_once() - def test_remove_user_favorite_tag(mocker): from superset.daos.tag import TagDAO # Mock the behavior of TagDAO and g - mock_session = mocker.patch("superset.daos.tag.db.session") mock_TagDAO = mocker.patch("superset.daos.tag.TagDAO") mock_tag = mocker.MagicMock(users_favorited=[]) mock_TagDAO.find_by_id.return_value = mock_tag @@ -68,9 +64,6 @@ def test_remove_user_favorite_tag(mocker): # Check that users_favorited no longer contains the user assert mock_user not in mock_tag.users_favorited - # Check that the db.session.was committed - mock_session.commit.assert_called_once() - def test_remove_user_favorite_tag_no_user(mocker): from superset.daos.tag import TagDAO diff --git a/tests/unit_tests/dao/user_test.py b/tests/unit_tests/dao/user_test.py index a2a74a55497cb..bf65c51121fac 100644 --- a/tests/unit_tests/dao/user_test.py +++ b/tests/unit_tests/dao/user_test.py @@ -90,4 +90,3 @@ def test_set_avatar_url_without_existing_attributes(mock_db_session): assert len(user.extra_attributes) == 1 assert user.extra_attributes[0].avatar_url == new_url mock_db_session.add.assert_called() # New attribute should be 
added - mock_db_session.commit.assert_called() diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index 488378f7ca3c2..f4534d216b9b7 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -115,7 +115,7 @@ def test_post_with_uuid( payload = response.json assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" - database = db.session.query(Database).one() + database = session.query(Database).one() assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb") diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index b20578784d034..9b9393d3a7359 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -36,7 +36,7 @@ def test_create_ssh_tunnel_command() -> None: ) properties = { - "database_id": database.id, + "database": database, "server_address": "123.132.123.1", "server_port": "3005", "username": "foo", diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py index 1456f7fd801cd..a24a94ec36d13 100644 --- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py +++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py @@ -31,7 +31,6 @@ def test_create_ssh_tunnel(): "username": "foo", "password": "bar", }, - commit=False, ) assert result is not None diff --git a/tests/unit_tests/security/manager_test.py b/tests/unit_tests/security/manager_test.py index a0ec87c52eb95..e35513f2ff8a1 100644 --- a/tests/unit_tests/security/manager_test.py +++ b/tests/unit_tests/security/manager_test.py @@ -413,7 +413,6 @@ def test_raise_for_access_chart_owner( owners=[alpha], ) session.add(slice) - session.flush() with override_user(alpha): sm.raise_for_access( diff --git a/tests/unit_tests/utils/lock_tests.py b/tests/unit_tests/utils/lock_tests.py 
index aa231bb0cf8f2..4c9121fe38744 100644 --- a/tests/unit_tests/utils/lock_tests.py +++ b/tests/unit_tests/utils/lock_tests.py @@ -22,8 +22,8 @@ import pytest from freezegun import freeze_time -from sqlalchemy.orm import Session, sessionmaker +from superset import db from superset.exceptions import CreateKeyValueDistributedLockFailedException from superset.key_value.types import JsonKeyValueCodec from superset.utils.lock import get_key, KeyValueDistributedLock @@ -32,56 +32,51 @@ OTHER_KEY = get_key("ns2", a=1, b=2) -def _get_lock(key: UUID, session: Session) -> Any: +def _get_lock(key: UUID) -> Any: from superset.key_value.models import KeyValueEntry - entry = session.query(KeyValueEntry).filter_by(uuid=key).first() + entry = db.session.query(KeyValueEntry).filter_by(uuid=key).first() if entry is None or entry.is_expired(): return None return JsonKeyValueCodec().decode(entry.value) -def _get_other_session() -> Session: - # This session is used to simulate what another worker will find in the metastore - # during the locking process. - from superset import db - - bind = db.session.get_bind() - SessionMaker = sessionmaker(bind=bind) - return SessionMaker() - - def test_key_value_distributed_lock_happy_path() -> None: """ Test successfully acquiring and returning the distributed lock. + + Note we use a nested transaction to ensure that the cleanup from the outer context + manager is correctly invoked, otherwise a partial rollback would occur leaving the + database in a fractured state. 
""" - session = _get_other_session() with freeze_time("2021-01-01"): - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None + with KeyValueDistributedLock("ns", a=1, b=2) as key: assert key == MAIN_KEY - assert _get_lock(key, session) is True - assert _get_lock(OTHER_KEY, session) is None - with pytest.raises(CreateKeyValueDistributedLockFailedException): - with KeyValueDistributedLock("ns", a=1, b=2): - pass + assert _get_lock(key) is True + assert _get_lock(OTHER_KEY) is None + + with db.session.begin_nested(): + with pytest.raises(CreateKeyValueDistributedLockFailedException): + with KeyValueDistributedLock("ns", a=1, b=2): + pass - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None def test_key_value_distributed_lock_expired() -> None: """ Test expiration of the distributed lock """ - session = _get_other_session() - with freeze_time("2021-01-01T"): - assert _get_lock(MAIN_KEY, session) is None + with freeze_time("2021-01-01"): + assert _get_lock(MAIN_KEY) is None with KeyValueDistributedLock("ns", a=1, b=2): - assert _get_lock(MAIN_KEY, session) is True - with freeze_time("2022-01-01T"): - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is True + with freeze_time("2022-01-01"): + assert _get_lock(MAIN_KEY) is None - assert _get_lock(MAIN_KEY, session) is None + assert _get_lock(MAIN_KEY) is None