diff --git a/superset/models.py b/superset/models.py index 4812a2408c3d3..a722e12e09ad6 100644 --- a/superset/models.py +++ b/superset/models.py @@ -776,15 +776,25 @@ def get_sqla_engine(self, schema=None): extra = self.get_extra() url = make_url(self.sqlalchemy_uri_decrypted) params = extra.get('engine_params', {}) - if self.backend == 'presto' and schema: - if '/' in url.database: - url.database = url.database.split('/')[0] + '/' + schema - else: - url.database += '/' + schema - elif schema: - url.database = schema + url.database = self.get_database_for_various_backend(url, schema) return create_engine(url, **params) + def get_database_for_various_backend(self, uri, default_database=None): + database = uri.database + if self.backend == 'presto' and default_database: + if '/' in database: + database = database.split('/')[0] + '/' + default_database + else: + database += '/' + default_database + # Postgres and Redshift use the concept of schema as a logical entity + # on top of the database, so the database should not be changed + # even if passed default_database + elif self.backend == 'redshift' or self.backend == 'postgresql': + pass + elif default_database: + database = default_database + return database + def get_reserved_words(self): return self.get_sqla_engine().dialect.preparer.reserved_words diff --git a/tests/model_tests.py b/tests/model_tests.py new file mode 100644 index 0000000000000..c9e60178abdb5 --- /dev/null +++ b/tests/model_tests.py @@ -0,0 +1,48 @@ +import unittest + +from sqlalchemy.engine.url import make_url + +from superset.models import Database + + +class DatabaseModelTestCase(unittest.TestCase): + def test_database_for_various_backend(self): + sqlalchemy_uri = 'presto://presto.airbnb.io:8080/hive/default' + model = Database(sqlalchemy_uri=sqlalchemy_uri) + url = make_url(model.sqlalchemy_uri) + db = model.get_database_for_various_backend(url, None) + assert db == 'hive/default' + db = model.get_database_for_various_backend(url, 'raw_data') + assert db == 'hive/raw_data' + + sqlalchemy_uri = 'redshift+psycopg2://superset:XXXXXXXXXX@redshift.airbnb.io:5439/prod' + model = Database(sqlalchemy_uri=sqlalchemy_uri) + url = make_url(model.sqlalchemy_uri) + db = model.get_database_for_various_backend(url, None) + assert db == 'prod' + db = model.get_database_for_various_backend(url, 'test') + assert db == 'prod' + + sqlalchemy_uri = 'postgresql+psycopg2://superset:XXXXXXXXXX@postgres.airbnb.io:5439/prod' + model = Database(sqlalchemy_uri=sqlalchemy_uri) + url = make_url(model.sqlalchemy_uri) + db = model.get_database_for_various_backend(url, None) + assert db == 'prod' + db = model.get_database_for_various_backend(url, 'adhoc') + assert db == 'prod' + + sqlalchemy_uri = 'hive://hive@hive.airbnb.io:10000/raw_data' + model = Database(sqlalchemy_uri=sqlalchemy_uri) + url = make_url(model.sqlalchemy_uri) + db = model.get_database_for_various_backend(url, None) + assert db == 'raw_data' + db = model.get_database_for_various_backend(url, 'adhoc') + assert db == 'adhoc' + + sqlalchemy_uri = 'mysql://superset:XXXXXXXXXX@mysql.airbnb.io/superset' + model = Database(sqlalchemy_uri=sqlalchemy_uri) + url = make_url(model.sqlalchemy_uri) + db = model.get_database_for_various_backend(url, None) + assert db == 'superset' + db = model.get_database_for_various_backend(url, 'adhoc') + assert db == 'adhoc'