@classmethod
def query_datasources_by_name(
        cls, session, database, datasource_name, schema=None):
    """Return the datasources named ``datasource_name`` under ``database``.

    :param session: SQLAlchemy session used to run the lookup
    :param database: parent object whose ``type`` attribute selects the
        datasource class (``'table'`` or ``'druid'``)
    :param datasource_name: table name for ``'table'`` sources, datasource
        name for ``'druid'`` sources
    :param schema: optional schema name; only applied to ``'table'``
        sources, and only when truthy
    :return: list of matching datasource objects; empty when nothing
        matches or when the source type is not handled here
    :raises KeyError: if ``database.type`` is not registered in
        ``SourceRegistry.sources``
    """
    datasource_class = SourceRegistry.sources[database.type]

    if database.type == 'table':
        query = (
            session.query(datasource_class)
            .filter_by(database_id=database.id)
            .filter_by(table_name=datasource_name))
        if schema:
            query = query.filter_by(schema=schema)
        return query.all()

    if database.type == 'druid':
        # The cluster model declares ``cluster_name`` as a string column and
        # ``id`` as an integer key, so filter on the name, not on the id.
        # NOTE(review): assumes the datasource's ``cluster_name`` column
        # stores the cluster's name -- confirm against the Druid
        # datasource model.
        return (
            session.query(datasource_class)
            .filter_by(cluster_name=database.cluster_name)
            .filter_by(datasource_name=datasource_name)
            .all()
        )

    # Callers iterate over the result, so hand back an empty list rather
    # than None for a registered but unhandled source type.
    return []
models, utils, dataframe, results_backend, sql_parse, sm) from superset.db_engine_specs import LimitMethod from superset.jinja_context import get_template_processor QueryStatus = models.QueryStatus @@ -19,16 +19,12 @@ celery_app = celery.Celery(config_source=app.config.get('CELERY_CONFIG')) -def is_query_select(sql): - return sql.upper().startswith('SELECT') - - def create_table_as(sql, table_name, schema=None, override=False): """Reformats the query into the create table as query. Works only for the single select SQL statements, in all other cases the sql query is not modified. - :param sql: string, sql query that will be executed + :param superset_query: string, sql query that will be executed :param table_name: string, will contain the results of the query execution :param override, boolean, table table_name will be dropped if true :return: string, create table as query @@ -41,12 +37,9 @@ def create_table_as(sql, table_name, schema=None, override=False): if schema: table_name = schema + '.' + table_name exec_sql = '' - if is_query_select(sql): - if override: - exec_sql = 'DROP TABLE IF EXISTS {table_name};\n' - exec_sql += "CREATE TABLE {table_name} AS \n{sql}" - else: - raise Exception("Could not generate CREATE TABLE statement") + if override: + exec_sql = 'DROP TABLE IF EXISTS {table_name};\n' + exec_sql += "CREATE TABLE {table_name} AS \n{sql}" return exec_sql.format(**locals()) @@ -76,12 +69,12 @@ def handle_error(msg): raise Exception(query.error_message) # Limit enforced only for retrieving the data, not for the CTA queries. 
RESULT_OPERATIONS = {'UNION', 'INTERSECT', 'EXCEPT'}
PRECEDES_TABLE_NAME = {'FROM', 'JOIN', 'DESC', 'DESCRIBE', 'WITH'}


# TODO: some sql_lab logic here.
class SupersetQuery(object):
    """Wraps a SQL string and collects the table names it refers to.

    The statement(s) are parsed with ``sqlparse`` on construction; names
    found after the keywords in ``PRECEDES_TABLE_NAME`` are gathered, and
    names recorded as aliases are removed from the result.
    """

    def __init__(self, sql_statement):
        """Parse ``sql_statement`` and populate the table name set.

        :param sql_statement: string, the SQL text to inspect
        """
        self._tokens = []
        self.sql = sql_statement
        self._table_names = set()
        self._alias_names = set()
        # TODO: multistatement support
        for statement in sqlparse.parse(self.sql):
            self.__extract_from_token(statement)
        # Names recorded as aliases while walking the statement are not
        # tables, so drop them.
        self._table_names = self._table_names - self._alias_names

    @property
    def tables(self):
        """Set of table names (``schema.table`` when qualified)."""
        return self._table_names

    # TODO: use sqlparse for this check.
    def is_select(self):
        """Return True when the SQL text starts with ``SELECT``.

        The check is case-insensitive and ignores leading whitespace, so a
        query that begins with a blank line or indentation still counts.
        """
        return self.sql.lstrip().upper().startswith('SELECT')

    @staticmethod
    def __precedes_table_name(token_value):
        """Tell whether an upper-cased keyword is followed by table names.

        Substring matching is used so that compound keywords such as
        ``LEFT OUTER JOIN`` are recognised through ``JOIN``.
        """
        for keyword in PRECEDES_TABLE_NAME:
            if keyword in token_value:
                return True
        return False

    @staticmethod
    def __get_full_name(identifier):
        """Return ``schema.table`` for a dotted identifier, else its name."""
        tokens = identifier.tokens
        # Three tokens are needed (schema, '.', table); checking for more
        # than two avoids an IndexError on a dangling ``schema.``.
        if len(tokens) > 2 and tokens[1].value == '.':
            return "{}.{}".format(tokens[0].value, tokens[2].value)
        return identifier.get_real_name()

    @staticmethod
    def __is_result_operation(keyword):
        """Tell whether ``keyword`` combines result sets (UNION etc.)."""
        for operation in RESULT_OPERATIONS:
            if operation in keyword.upper():
                return True
        return False

    @staticmethod
    def __is_identifier(token):
        """Tell whether ``token`` is an Identifier or an IdentifierList."""
        return (
            isinstance(token, IdentifierList) or isinstance(token, Identifier))

    def __process_identifier(self, identifier):
        """Record ``identifier`` as a table, or descend into a subselect."""
        # exclude subselects
        if '(' not in '{}'.format(identifier):
            self._table_names.add(SupersetQuery.__get_full_name(identifier))
            return

        # store aliases
        if hasattr(identifier, 'get_alias'):
            self._alias_names.add(identifier.get_alias())
        if hasattr(identifier, 'tokens'):
            # some aliases are not parsed properly
            if identifier.tokens[0].ttype == Name:
                self._alias_names.add(identifier.tokens[0].value)
        self.__extract_from_token(identifier)

    def __extract_from_token(self, token):
        """Walk ``token``'s children and collect the table names in them.

        Tokens without a ``tokens`` attribute are ignored.
        """
        if not hasattr(token, 'tokens'):
            return

        table_name_preceding_token = False

        for item in token.tokens:
            # Non-identifier groups are searched recursively for nested
            # FROM / JOIN clauses.
            if item.is_group and not self.__is_identifier(item):
                self.__extract_from_token(item)

            if item.ttype in Keyword:
                if SupersetQuery.__precedes_table_name(item.value.upper()):
                    table_name_preceding_token = True
                    continue

            if not table_name_preceding_token:
                continue

            if item.ttype in Keyword:
                if SupersetQuery.__is_result_operation(item.value):
                    table_name_preceding_token = False
                    continue
                # FROM clause is over
                break

            if isinstance(item, Identifier):
                self.__process_identifier(item)

            if isinstance(item, IdentifierList):
                # Distinct loop name so the ``token`` parameter is not
                # rebound while its children are being iterated.
                for sub_token in item.tokens:
                    if SupersetQuery.__is_identifier(sub_token):
                        self.__process_identifier(sub_token)
class SupersetTestCase(unittest.TestCase):
    """Checks the table names ``sql_parse.SupersetQuery`` extracts."""

    def extract_tables(self, query):
        """Return the set of table names parsed out of ``query``."""
        sq = sql_parse.SupersetQuery(query)
        return sq.tables

    def test_simple_select(self):
        query = "SELECT * FROM tbname"
        self.assertEqual({"tbname"}, self.extract_tables(query))

        # underscores
        query = "SELECT * FROM tb_name"
        self.assertEqual({"tb_name"},
                         self.extract_tables(query))

        # quotes
        query = 'SELECT * FROM "tbname"'
        self.assertEqual({"tbname"}, self.extract_tables(query))

        # schema
        self.assertEqual(
            {"schemaname.tbname"},
            self.extract_tables("SELECT * FROM schemaname.tbname"))

        # explicit column list
        query = "SELECT field1, field2 FROM tb_name"
        self.assertEqual({"tb_name"}, self.extract_tables(query))

        # several tables in one FROM clause
        query = "SELECT t1.f1, t2.f2 FROM t1, t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    def test_select_named_table(self):
        query = "SELECT a.date, a.field FROM left_table a LIMIT 10"
        self.assertEqual(
            {"left_table"}, self.extract_tables(query))

    def test_reverse_select(self):
        query = "FROM t1 SELECT field"
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_subselect(self):
        query = """
          SELECT sub.*
              FROM (
                    SELECT *
                      FROM s1.t1
                     WHERE day_of_week = 'Friday'
                   ) sub, s2.t2
          WHERE sub.resolution = 'NONE'
        """
        self.assertEqual({"s1.t1", "s2.t2"},
                         self.extract_tables(query))

        query = """
          SELECT sub.*
              FROM (
                    SELECT *
                      FROM s1.t1
                     WHERE day_of_week = 'Friday'
                   ) sub
          WHERE sub.resolution = 'NONE'
        """
        self.assertEqual({"s1.t1"}, self.extract_tables(query))

        query = """
            SELECT * FROM t1
            WHERE s11 > ANY
             (SELECT COUNT(*) /* no hint */ FROM t2
               WHERE NOT EXISTS
                (SELECT * FROM t3
                  WHERE ROW(5*t2.s1,77)=
                    (SELECT 50,11*s1 FROM t4)));
        """
        self.assertEqual({"t1", "t2", "t3", "t4"},
                         self.extract_tables(query))

    def test_select_in_expression(self):
        query = "SELECT f1, (SELECT count(1) FROM t2) FROM t1"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    def test_union(self):
        query = "SELECT * FROM t1 UNION SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = "SELECT * FROM t1 UNION ALL SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = "SELECT * FROM t1 INTERSECT ALL SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    def test_select_from_values(self):
        query = "SELECT * FROM VALUES (13, 42)"
        self.assertFalse(self.extract_tables(query))

    def test_select_array(self):
        query = """
            SELECT ARRAY[1, 2, 3] AS my_array
            FROM t1 LIMIT 10
        """
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_select_if(self):
        query = """
            SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
            FROM t1 LIMIT 10
        """
        self.assertEqual({"t1"}, self.extract_tables(query))

    # SHOW TABLES ((FROM | IN) qualifiedName)? (LIKE pattern=STRING)?
    def test_show_tables(self):
        query = 'SHOW TABLES FROM s1 like "%order%"'
        # TODO: figure out what should code do here
        self.assertEqual({"s1"}, self.extract_tables(query))

    # SHOW COLUMNS (FROM | IN) qualifiedName
    def test_show_columns(self):
        query = "SHOW COLUMNS FROM t1"
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_where_subquery(self):
        query = """
          SELECT name
            FROM t1
            WHERE regionkey = (SELECT max(regionkey) FROM t2)
        """
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = """
          SELECT name
            FROM t1
            WHERE regionkey IN (SELECT regionkey FROM t2)
        """
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = """
          SELECT name
            FROM t1
            WHERE regionkey EXISTS (SELECT regionkey FROM t2)
        """
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

    # DESCRIBE | DESC qualifiedName
    def test_describe(self):
        self.assertEqual({"t1"}, self.extract_tables("DESCRIBE t1"))
        self.assertEqual({"t1"}, self.extract_tables("DESC t1"))

    # SHOW PARTITIONS FROM qualifiedName (WHERE booleanExpression)?
    # (ORDER BY sortItem (',' sortItem)*)? (LIMIT limit=(INTEGER_VALUE | ALL))?
    def test_show_partitions(self):
        query = """
            SHOW PARTITIONS FROM orders
            WHERE ds >= '2013-01-01' ORDER BY ds DESC;
        """
        self.assertEqual({"orders"}, self.extract_tables(query))

    def test_join(self):
        query = "SELECT t1.*, t2.* FROM t1 JOIN t2 ON t1.a = t2.a;"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        # subquery + join
        query = """
            SELECT a.date, b.name FROM
                left_table a
                JOIN (
                  SELECT
                    CAST((b.year) as VARCHAR) date,
                    name
                  FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        query = """
            SELECT a.date, b.name FROM
                left_table a
                LEFT INNER JOIN (
                  SELECT
                    CAST((b.year) as VARCHAR) date,
                    name
                  FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        query = """
            SELECT a.date, b.name FROM
                left_table a
                RIGHT OUTER JOIN (
                  SELECT
                    CAST((b.year) as VARCHAR) date,
                    name
                  FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        query = """
            SELECT a.date, b.name FROM
                left_table a
                FULL OUTER JOIN (
                  SELECT
                    CAST((b.year) as VARCHAR) date,
                    name
                  FROM right_table
                ) b
                ON a.date = b.date
        """
        self.assertEqual({"left_table", "right_table"},
                         self.extract_tables(query))

        # TODO: add SEMI join support, SQL Parse does not handle it.
        # query = """
        #     SELECT a.date, b.name FROM
        #         left_table a
        #         LEFT SEMI JOIN (
        #           SELECT
        #             CAST((b.year) as VARCHAR) date,
        #             name
        #           FROM right_table
        #         ) b
        #         ON a.date = b.date
        # """
        # self.assertEqual({"left_table", "right_table"},
        #                  self.extract_tables(query))

    def test_combinations(self):
        query = """
            SELECT * FROM t1
            WHERE s11 > ANY
             (SELECT * FROM t1 UNION ALL SELECT * FROM (
               SELECT t6.*, t3.* FROM t6 JOIN t3 ON t6.a = t3.a) tmp_join
               WHERE NOT EXISTS
                (SELECT * FROM t3
                  WHERE ROW(5*t3.s1,77)=
                    (SELECT 50,11*s1 FROM t4)));
        """
        self.assertEqual({"t1", "t3", "t4", "t6"},
                         self.extract_tables(query))

        query = """
        SELECT * FROM (SELECT * FROM (SELECT * FROM (SELECT * FROM EmployeeS)
            AS S1) AS S2) AS S3;
        """
        self.assertEqual({"EmployeeS"}, self.extract_tables(query))

    def test_with(self):
        query = """
            WITH
              x AS (SELECT a FROM t1),
              y AS (SELECT a AS b FROM t2),
              z AS (SELECT b AS c FROM t3)
            SELECT c FROM z;
        """
        self.assertEqual({"t1", "t2", "t3"},
                         self.extract_tables(query))

        query = """
            WITH
              x AS (SELECT a FROM t1),
              y AS (SELECT a AS b FROM x),
              z AS (SELECT b AS c FROM y)
            SELECT c FROM z;
        """
        self.assertEqual({"t1"}, self.extract_tables(query))

    def test_reusing_aliases(self):
        query = """
            with q1 as ( select key from q2 where key = '5'),
            q2 as ( select key from src where key = '5')
            select * from (select key from q1) a;
        """
        self.assertEqual({"src"}, self.extract_tables(query))

    # Named with the ``test_`` prefix so the runner actually collects it.
    def test_multistatement(self):
        query = "SELECT * FROM t1; SELECT * FROM t2"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))

        query = "SELECT * FROM t1; SELECT * FROM t2;"
        self.assertEqual({"t1", "t2"}, self.extract_tables(query))