Skip to content

Commit

Permalink
Fixing confusion when selecting schema across engines (#2572)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistercrunch authored Apr 10, 2017
1 parent 40b3d3b commit ac84fc2
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 53 deletions.
38 changes: 38 additions & 0 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,27 @@ def extract_error_message(cls, e):
"""Extract error message for queries"""
return utils.error_msg_from_exception(e)

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Return the URI to connect with, given the currently selected schema.

    ``uri`` represents the URI as entered when saving the database and
    ``selected_schema`` is the schema currently active, presumably in the
    SQL Lab dropdown. For some database engines an altered URI can be
    returned that connects straight to the active schema, meaning users
    do not have to prefix object names with the schema name.

    Some database engines have two levels of namespacing: database and
    schema (postgres, oracle, mssql, ...). For those it is probably
    better not to alter the database component of the URI with the
    schema name; it will not work.

    Some database drivers, like presto, accept "{catalog}/{schema}" in
    the database component of the URL; that can be handled here.

    This default implementation ignores ``selected_schema`` and returns
    ``uri`` as is. ``selected_schema`` defaults to ``None`` so that the
    signature matches the engine-specific overrides.
    """
    return uri

@classmethod
def sql_preprocessor(cls, sql):
"""If the SQL needs to be altered prior to running it
Expand Down Expand Up @@ -290,6 +311,12 @@ def convert_dttm(cls, target_type, dttm):
dttm.strftime('%Y-%m-%d %H:%M:%S'))
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Point the URI's database component at ``selected_schema``.

    :param uri: URL-like object exposing a writable ``database`` attribute
    :param selected_schema: schema to connect to; when falsy (``None`` or
        empty string) ``uri`` is returned as is
    :returns: ``uri`` itself when no schema is selected, otherwise a
        shallow copy of ``uri`` whose ``database`` is ``selected_schema``
    """
    import copy

    if not selected_schema:
        return uri
    # Work on a shallow copy so the caller's URL object is never altered
    # behind its back; callers must use the returned value.
    adjusted = copy.copy(uri)
    adjusted.database = selected_schema
    return adjusted

@classmethod
def epoch_to_dttm(cls):
return "from_unixtime({col})"
Expand Down Expand Up @@ -328,6 +355,17 @@ def patch(cls):
from superset.db_engines import presto as patched_presto
presto.Cursor.cancel = patched_presto.cancel

@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Swap or append the schema in a "{catalog}/{schema}" database component.

    The part of ``uri.database`` before the first ``/`` is kept and
    ``selected_schema`` is placed after it, so ``hive/default`` and
    ``hive`` both become ``hive/<selected_schema>``.

    :param uri: URL-like object exposing a writable ``database`` attribute
    :param selected_schema: schema to connect to; when falsy ``uri`` is
        returned as is
    :returns: ``uri`` itself when there is nothing to adjust, otherwise a
        shallow copy of ``uri`` with the adjusted ``database``
    """
    import copy

    database = uri.database
    # Without a database component there is no catalog to attach the
    # schema to (and ``'/' in None`` would raise), so leave the URI alone.
    if not selected_schema or not database:
        return uri
    # split('/')[0] is the whole string when there is no '/', so a single
    # expression covers both the "catalog" and "catalog/schema" forms.
    catalog = database.split('/')[0]
    # Work on a shallow copy so the caller's URL object is never altered.
    adjusted = copy.copy(uri)
    adjusted.database = catalog + '/' + selected_schema
    return adjusted

@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
Expand Down
25 changes: 4 additions & 21 deletions superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,26 +560,10 @@ def set_sqlalchemy_uri(self, uri):

def get_sqla_engine(self, schema=None):
extra = self.get_extra()
url = make_url(self.sqlalchemy_uri_decrypted)
uri = make_url(self.sqlalchemy_uri_decrypted)
params = extra.get('engine_params', {})
url.database = self.get_database_for_various_backend(url, schema)
return create_engine(url, **params)

def get_database_for_various_backend(self, uri, default_database=None):
database = uri.database
if self.backend == 'presto' and default_database:
if '/' in database:
database = database.split('/')[0] + '/' + default_database
else:
database += '/' + default_database
# Postgres and Redshift use the concept of schema as a logical entity
# on top of the database, so the database should not be changed
# even if passed default_database
elif self.backend in ('redshift', 'postgresql', 'sqlite'):
pass
elif default_database:
database = default_database
return database
uri = self.db_engine_spec.adjust_database_uri(uri, schema)
return create_engine(uri, **params)

def get_reserved_words(self):
return self.get_sqla_engine().dialect.preparer.reserved_words
Expand Down Expand Up @@ -662,9 +646,8 @@ def all_schema_names(self):

@property
def db_engine_spec(self):
engine_name = self.get_sqla_engine().name or 'base'
return db_engine_specs.engines.get(
engine_name, db_engine_specs.BaseEngineSpec)
self.backend, db_engine_specs.BaseEngineSpec)

def grains(self):
"""Defines time granularity database-specific expressions.
Expand Down
5 changes: 3 additions & 2 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ def generate_download_headers(extension):
class DatabaseView(SupersetModelView, DeleteMixin): # noqa
datamodel = SQLAInterface(models.Database)
list_columns = [
'verbose_name', 'backend', 'allow_run_sync', 'allow_run_async',
'allow_dml', 'creator', 'changed_on_', 'database_name']
'database_name', 'backend', 'allow_run_sync', 'allow_run_async',
'allow_dml', 'creator', 'modified']
add_columns = [
'database_name', 'sqlalchemy_uri', 'cache_timeout', 'extra',
'expose_in_sqllab', 'allow_run_sync', 'allow_run_async',
Expand Down Expand Up @@ -1351,6 +1351,7 @@ def testconn(self):
engine.connect()
return json.dumps(engine.table_names(), indent=4)
except Exception as e:
logging.exception(e)
return json_error_response((
"Connection failed!\n\n"
"The error message returned was:\n{}").format(e))
Expand Down
68 changes: 38 additions & 30 deletions tests/model_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,43 +6,51 @@


class DatabaseModelTestCase(unittest.TestCase):
def test_database_for_various_backend(self):

def test_database_schema_presto(self):
sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'hive/default'
db = model.get_database_for_various_backend(url, 'raw_data')
assert db == 'hive/raw_data'

sqlalchemy_uri = 'redshift+psycopg2://superset:[email protected]:5439/prod'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive/default', db)

db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)

sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'prod'
db = model.get_database_for_various_backend(url, 'test')
assert db == 'prod'

sqlalchemy_uri = 'postgresql+psycopg2://superset:[email protected]:5439/prod'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive', db)

db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)

def test_database_schema_postgres(self):
sqlalchemy_uri = 'postgresql+psycopg2://postgres.airbnb.io:5439/prod'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'prod'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'prod'

sqlalchemy_uri = 'hive://[email protected]:10000/raw_data'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('prod', db)

db = make_url(model.get_sqla_engine(schema='foo').url).database
self.assertEquals('prod', db)

def test_database_schema_hive(self):
sqlalchemy_uri = 'hive://[email protected]:10000/hive/default'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'raw_data'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'adhoc'
db = make_url(model.get_sqla_engine().url).database
self.assertEquals('hive/default', db)

db = make_url(model.get_sqla_engine(schema='core_db').url).database
self.assertEquals('hive/core_db', db)

sqlalchemy_uri = 'mysql://superset:[email protected]/superset'
def test_database_schema_mysql(self):
sqlalchemy_uri = 'mysql://root@localhost/superset'
model = Database(sqlalchemy_uri=sqlalchemy_uri)
url = make_url(model.sqlalchemy_uri)
db = model.get_database_for_various_backend(url, None)
assert db == 'superset'
db = model.get_database_for_various_backend(url, 'adhoc')
assert db == 'adhoc'

db = make_url(model.get_sqla_engine().url).database
self.assertEquals('superset', db)

db = make_url(model.get_sqla_engine(schema='staging').url).database
self.assertEquals('staging', db)

0 comments on commit ac84fc2

Please sign in to comment.