Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: test connection raises only command exceptions #12307

Merged
merged 5 commits into from
Jan 8, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 4 additions & 33 deletions superset/databases/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,8 @@
from flask import g, request, Response, send_file
from flask_appbuilder.api import expose, protect, rison, safe
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_babel import gettext as _
from marshmallow import ValidationError
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import (
DBAPIError,
NoSuchModuleError,
NoSuchTableError,
OperationalError,
SQLAlchemyError,
)
from sqlalchemy.exc import NoSuchTableError, OperationalError, SQLAlchemyError

from superset import event_logger
from superset.commands.exceptions import CommandInvalidError
Expand All @@ -49,7 +41,7 @@
DatabaseImportError,
DatabaseInvalidError,
DatabaseNotFoundError,
DatabaseSecurityUnsafeError,
DatabaseTestConnectionFailedError,
DatabaseUpdateFailedError,
)
from superset.databases.commands.export import ExportDatabasesCommand
Expand Down Expand Up @@ -589,29 +581,8 @@ def test_connection( # pylint: disable=too-many-return-statements
try:
TestConnectionDatabaseCommand(g.user, item).run()
return self.response(200, message="OK")
except (NoSuchModuleError, ModuleNotFoundError):
logger.info("Invalid driver")
driver_name = make_url(item.get("sqlalchemy_uri")).drivername
return self.response(
400,
message=_("Could not load database driver: {}").format(driver_name),
driver_name=driver_name,
)
except DatabaseSecurityUnsafeError as ex:
return self.response_422(message=ex)
except DBAPIError:
logger.warning("Connection failed")
return self.response(
500,
message=_("Connection failed, please check your connection settings"),
)
except Exception as ex: # pylint: disable=broad-except
logger.error("Unexpected error %s", type(ex).__name__)
return self.response_400(
message=_(
"Unexpected error occurred, please check your logs for details"
)
)
except DatabaseTestConnectionFailedError as ex:
return self.response_422(message=str(ex))

@expose("/<int:pk>/related_objects/", methods=["GET"])
@protect()
Expand Down
17 changes: 14 additions & 3 deletions superset/databases/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
ImportFailedError,
UpdateFailedError,
)
from superset.security.analytics_db_safety import DBSecurityException


class DatabaseInvalidError(CommandInvalidError):
Expand Down Expand Up @@ -102,7 +101,7 @@ class DatabaseUpdateFailedError(UpdateFailedError):
class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors
DatabaseCreateFailedError, DatabaseUpdateFailedError,
):
message = _("Could not connect to database.")
message = _("Connection failed, please check your connection settings")


class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError):
Expand All @@ -117,9 +116,21 @@ class DatabaseDeleteFailedReportsExistError(DatabaseDeleteFailedError):
message = _("There are associated alerts or reports")


class DatabaseSecurityUnsafeError(DBSecurityException):
class DatabaseTestConnectionFailedError(CommandException):
message = _("Connection failed, please check your connection settings")


class DatabaseSecurityUnsafeError(DatabaseTestConnectionFailedError):
message = _("Stopped an unsafe database connection")


class DatabaseTestConnectionDriverError(DatabaseTestConnectionFailedError):
message = _("Could not load database driver")


class DatabaseTestConnectionUnexpectedError(DatabaseTestConnectionFailedError):
message = _("Unexpected error occurred, please check your logs for details")


class DatabaseImportError(ImportFailedError):
message = _("Import database failed for an unknown reason")
30 changes: 22 additions & 8 deletions superset/databases/commands/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,17 @@
from typing import Any, Dict, Optional

from flask_appbuilder.security.sqla.models import User
from sqlalchemy.exc import DBAPIError
from flask_babel import gettext as _
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import DBAPIError, NoSuchModuleError

from superset.commands.base import BaseCommand
from superset.databases.commands.exceptions import DatabaseSecurityUnsafeError
from superset.databases.commands.exceptions import (
DatabaseSecurityUnsafeError,
DatabaseTestConnectionDriverError,
DatabaseTestConnectionFailedError,
DatabaseTestConnectionUnexpectedError,
)
from superset.databases.dao import DatabaseDAO
from superset.models.core import Database
from superset.security.analytics_db_safety import DBSecurityException
Expand All @@ -38,11 +45,10 @@ def __init__(self, user: User, data: Dict[str, Any]):

def run(self) -> None:
self.validate()
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted
try:
uri = self._properties.get("sqlalchemy_uri", "")
if self._model and uri == self._model.safe_sqlalchemy_uri():
uri = self._model.sqlalchemy_uri_decrypted

database = DatabaseDAO.build_db_for_connection_test(
server_cert=self._properties.get("server_cert", ""),
extra=self._properties.get("extra", "{}"),
Expand All @@ -57,9 +63,17 @@ def run(self) -> None:
with closing(engine.raw_connection()) as conn:
if not engine.dialect.do_ping(conn):
raise DBAPIError(None, None, None)
except (NoSuchModuleError, ModuleNotFoundError):
driver_name = make_url(uri).drivername
raise DatabaseTestConnectionDriverError(
message=_("Could not load database driver: {}").format(driver_name),
)
except DBAPIError:
raise DatabaseTestConnectionFailedError()
except DBSecurityException as ex:
logger.warning(ex)
raise DatabaseSecurityUnsafeError()
raise DatabaseSecurityUnsafeError(message=str(ex))
except Exception:
raise DatabaseTestConnectionUnexpectedError()

def validate(self) -> None:
database_name = self._properties.get("database_name")
Expand Down
6 changes: 4 additions & 2 deletions superset/security/analytics_db_safety.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,11 @@
# under the License.
from sqlalchemy.engine.url import URL

from superset.exceptions import SupersetException

class DBSecurityException(Exception):
""" Exception to prevent a security issue with connecting a DB """

class DBSecurityException(SupersetException):
""" Exception to prevent a security issue with connecting to a DB """

status = 400

Expand Down
14 changes: 8 additions & 6 deletions tests/databases/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,9 @@ def test_create_database_conn_fail(self):
self.login(username="admin")
response = self.client.post(uri, json=database_data)
response_data = json.loads(response.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(response.status_code, 422)
self.assertEqual(response_data, expected_response)

Expand Down Expand Up @@ -431,7 +433,9 @@ def test_update_database_conn_fail(self):
self.login(username="admin")
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": "Could not connect to database."}
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
# Cleanup
Expand Down Expand Up @@ -787,11 +791,10 @@ def test_test_connection_failed(self):
}
url = "api/v1/database/test_connection"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "broken",
"message": "Could not load database driver: broken",
}
self.assertEqual(response, expected_response)
Expand All @@ -803,11 +806,10 @@ def test_test_connection_failed(self):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"driver_name": "mssql+pymssql",
"message": "Could not load database driver: mssql+pymssql",
}
self.assertEqual(response, expected_response)
Expand Down