Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Catch a missing database with MySQL in the connection string #108

Merged
merged 4 commits into from
Dec 29, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion jupyterlab_sql/connection_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,21 @@


def is_sqlite(url):
backend = sqlalchemy.engine.url.make_url(url).get_backend_name()
backend = _to_sqlalchemy_url(url).get_backend_name()
return backend == "sqlite"


def is_mysql(url):
backend = _to_sqlalchemy_url(url).get_backend_name()
return backend == "mysql"


def has_database(url):
database = _to_sqlalchemy_url(url).database
# database is either None or an empty string, depending on
# whether the URL contains a trailing slash.
return bool(database)


def _to_sqlalchemy_url(url):
return sqlalchemy.engine.url.make_url(url)
12 changes: 11 additions & 1 deletion jupyterlab_sql/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@

from .serializer import make_row_serializable
from .cache import Cache
from .connection_url import is_sqlite
from .connection_url import is_sqlite, is_mysql, has_database


class InvalidConnectionUrl(Exception):
pass


class QueryResult:
Expand All @@ -28,6 +32,12 @@ def __init__(self):
self._sqlite_engine_cache = Cache()

def get_table_names(self, connection_url):
if is_mysql(connection_url) and not has_database(connection_url):
raise InvalidConnectionUrl(
"You need to specify a database name in the connection "
"URL for MySQL databases. Use, for instance, "
"`mysql://localhost/employees`."
)
engine = self._get_engine(connection_url)
return engine.table_names()

Expand Down
35 changes: 35 additions & 0 deletions jupyterlab_sql/tests/test_connection_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,38 @@ def test_sqlite(url):
)
def test_not_sqlite(url):
assert not connection_url.is_sqlite(url)


@pytest.mark.parametrize(
"url",
[
"mysql:///employees",
"mysql://localhost/employees",
"mysql+mysqldb:///employees",
"mysql+pymysql:///employees",
],
)
def test_mysql(url):
assert connection_url.is_mysql(url)


@pytest.mark.parametrize(
"url", ["postgres://localhost:5432/postgres", "sqlite://"]
)
def test_not_mysql(url):
assert not connection_url.is_mysql(url)


@pytest.mark.parametrize(
"url",
["mysql:///employees", "mysql://localhost/employees", "sqlite:///foo.db"],
)
def test_has_database(url):
assert connection_url.has_database(url)


@pytest.mark.parametrize(
"url", ["mysql://", "mysql://localhost/", "sqlite://"]
)
def test_not_has_database(url):
assert not connection_url.has_database(url)