Skip to content

Commit

Permalink
refactor: test connection raises only command exceptions (#12307)
Browse files Browse the repository at this point in the history
* refactor: test connection raises only command exceptions

* fix tests

* fix tests

* fix tests

* lint fix
  • Loading branch information
dpgaspar authored Jan 8, 2021
1 parent fecfc34 commit c685c9e
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 52 deletions.
37 changes: 4 additions & 33 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,8 @@
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
from marshmallow import ValidationError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
DBAPIError,
NoSuchModuleError,
NoSuchTableError,
OperationalError,
SQLAlchemyError,
)
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError

from superset import event_logger
from superset.commands.exceptions import CommandInvalidError
Expand All @@ -49,7 +41,7 @@
DatabaseImportError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseTestConnectionFailedError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.export import ExportDatabasesCommand
Expand Down Expand Up @@ -589,29 +581,8 @@ def test_connection( # pylint: disable=too-many-return-statements
try:
TestConnectionDatabaseCommand(g.user, item).run()
return self.response(200, message="OK")
except (NoSuchModuleError, ModuleNotFoundError):
logger.info("Invalid driver")
driver_name = make_url(item.get("sqlalchemy_uri")).drivername
return self.response(
400,
message=_("Could not load database driver: {}").format(driver_name),
driver_name=driver_name,
)
except DatabaseSecurityUnsafeError as ex:
return self.response_422(message=ex)
except DBAPIError:
logger.warning("Connection failed")
return self.response(
500,
message=_("Connection failed, please check your connection settings"),
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
return self.response_400(
message=_(
"Unexpected error occurred, please check your logs for details"
)
)
except DatabaseTestConnectionFailedError as ex:
return self.response_422(message=str(ex))

@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
Expand Down
17 changes: 14 additions & 3 deletions superset/databases/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
ImportFailedError,
UpdateFailedError,
)
from superset.security.analytics_db_safety import DBSecurityException


class DatabaseInvalidError(CommandInvalidError):
Expand Down Expand Up @@ -102,7 +101,7 @@ class DatabaseUpdateFailedError(UpdateFailedError):
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
DatabaseCreateFailedError, DatabaseUpdateFailedError,
):
message = _("Could not connect to database.")
message = _("Connection failed, please check your connection settings")


class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
Expand All @@ -117,9 +116,21 @@ class DatabaseDeleteFailedReportsExistError(DatabaseDeleteFailedError):
message = _("There are associated alerts or reports")


class DatabaseSecurityUnsafeError(DBSecurityException):
class DatabaseTestConnectionFailedError(CommandException):
message = _("Connection failed, please check your connection settings")


class DatabaseSecurityUnsafeError(DatabaseTestConnectionFailedError):
message = _("Stopped an unsafe database connection")


class DatabaseTestConnectionDriverError(DatabaseTestConnectionFailedError):
message = _("Could not load database driver")


class DatabaseTestConnectionUnexpectedError(DatabaseTestConnectionFailedError):
message = _("Unexpected error occurred, please check your logs for details")


class DatabaseImportError(ImportFailedError):
message = _("Import database failed for an unknown reason")
30 changes: 22 additions & 8 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@
from typing import Any, Dict, Optional

from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import DBAPIError
from flask_babel import gettext as _
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DBAPIError, NoSuchModuleError

from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
from superset.databases.commands.exceptions import (
DatabaseSecurityUnsafeError,
DatabaseTestConnectionDriverError,
DatabaseTestConnectionFailedError,
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
from superset.security.analytics_db_safety import DBSecurityException
Expand All @@ -38,11 +45,10 @@ def __init__(self, user: User, data: Dict[str, Any]):

def run(self) -> None:
self.validate()
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
try:
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted

database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=self._properties.get("extra", "{}"),
Expand All @@ -57,9 +63,17 @@ def run(self) -> None:
with closing(engine.raw_connection()) as conn:
if not engine.dialect.do_ping(conn):
raise DBAPIError(None, None, None)
except (NoSuchModuleError, ModuleNotFoundError):
driver_name = make_url(uri).drivername
raise DatabaseTestConnectionDriverError(
message=_("Could not load database driver: {}").format(driver_name),
)
except DBAPIError:
raise DatabaseTestConnectionFailedError()
except DBSecurityException as ex:
logger.warning(ex)
raise DatabaseSecurityUnsafeError()
raise DatabaseSecurityUnsafeError(message=str(ex))
except Exception:
raise DatabaseTestConnectionUnexpectedError()

def validate(self) -> None:
database_name = self._properties.get("database_name")
Expand Down
6 changes: 4 additions & 2 deletions superset/security/analytics_db_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
from sqlalchemy.engine.url import URL

from superset.exceptions import SupersetException

class DBSecurityException(Exception):
""" Exception to prevent a security issue with connecting a DB """

class DBSecurityException(SupersetException):
""" Exception to prevent a security issue with connecting to a DB """

status = 400

Expand Down
14 changes: 8 additions & 6 deletions tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def test_create_database_conn_fail(self):
self.login(username="admin")
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(response.status_code, 422)
self.assertEqual(response_data, expected_response)

Expand Down Expand Up @@ -431,7 +433,9 @@ def test_update_database_conn_fail(self):
self.login(username="admin")
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
# Cleanup
Expand Down Expand Up @@ -787,11 +791,10 @@ def test_test_connection_failed(self):
}
url = "api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "broken",
"message": "Could not load database driver: broken",
}
self.assertEqual(response, expected_response)
Expand All @@ -803,11 +806,10 @@ def test_test_connection_failed(self):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "mssql+pymssql",
"message": "Could not load database driver: mssql+pymssql",
}
self.assertEqual(response, expected_response)
Expand Down

0 comments on commit c685c9e

Please sign in to comment.