-
Notifications
You must be signed in to change notification settings - Fork 14.6k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
af176ee
commit ca046d3
Showing
6 changed files
with
43 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,15 +6,14 @@ | |
from __future__ import unicode_literals | ||
|
||
import json | ||
import os | ||
import subprocess | ||
import time | ||
import unittest | ||
|
||
import pandas as pd | ||
from past.builtins import basestring | ||
|
||
from superset import app, db, security_manager | ||
from superset import app, db | ||
from superset.models.helpers import QueryStatus | ||
from superset.models.sql_lab import Query | ||
from superset.sql_parse import SupersetQuery | ||
|
@@ -23,13 +22,12 @@ | |
|
||
|
||
BASE_DIR = app.config.get('BASE_DIR') | ||
CELERY_SLEEP_TIME = 5 | ||
|
||
|
||
class CeleryConfig(object): | ||
BROKER_URL = 'sqla+sqlite:///' + app.config.get('SQL_CELERY_DB_FILE_PATH') | ||
BROKER_URL = app.config.get('CELERY_RESULT_BACKEND') | ||
CELERY_IMPORTS = ('superset.sql_lab', ) | ||
CELERY_RESULT_BACKEND = ( | ||
'db+sqlite:///' + app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH')) | ||
CELERY_ANNOTATIONS = {'sql_lab.add': {'rate_limit': '10/s'}} | ||
CONCURRENCY = 1 | ||
|
||
|
@@ -91,28 +89,11 @@ def get_query_by_id(self, id): | |
def setUpClass(cls): | ||
db.session.query(Query).delete() | ||
db.session.commit() | ||
try: | ||
os.remove(app.config.get('SQL_CELERY_DB_FILE_PATH')) | ||
except OSError as e: | ||
app.logger.warn(str(e)) | ||
try: | ||
os.remove(app.config.get('SQL_CELERY_RESULTS_DB_FILE_PATH')) | ||
except OSError as e: | ||
app.logger.warn(str(e)) | ||
|
||
security_manager.sync_role_definitions() | ||
|
||
worker_command = BASE_DIR + '/bin/superset worker' | ||
|
||
worker_command = BASE_DIR + '/bin/superset worker -w 2' | ||
subprocess.Popen( | ||
worker_command, shell=True, stdout=subprocess.PIPE) | ||
|
||
admin = security_manager.find_user('admin') | ||
if not admin: | ||
security_manager.add_user( | ||
'admin', 'admin', ' user', '[email protected]', | ||
security_manager.find_role('Admin'), | ||
password='general') | ||
|
||
@classmethod | ||
def tearDownClass(cls): | ||
subprocess.call( | ||
|
@@ -124,7 +105,7 @@ def tearDownClass(cls): | |
shell=True, | ||
) | ||
|
||
def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp', | ||
def run_sql(self, db_id, sql, client_id=None, cta='false', tmp_table='tmp', | ||
async_='false'): | ||
self.login() | ||
resp = self.client.post( | ||
|
@@ -142,7 +123,8 @@ def run_sql(self, db_id, sql, client_id, cta='false', tmp_table='tmp', | |
return json.loads(resp.data.decode('utf-8')) | ||
|
||
def test_run_sync_query_dont_exist(self): | ||
db_id = get_main_database(db.session).id | ||
main_db = get_main_database(db.session) | ||
db_id = main_db.id | ||
sql_dont_exist = 'SELECT name FROM table_dont_exist' | ||
result1 = self.run_sql(db_id, sql_dont_exist, '1', cta='true') | ||
self.assertTrue('error' in result1) | ||
|
@@ -151,11 +133,13 @@ def test_run_sync_query_cta(self): | |
main_db = get_main_database(db.session) | ||
db_id = main_db.id | ||
eng = main_db.get_sqla_engine() | ||
tmp_table_name = 'tmp_async_22' | ||
self.drop_table_if_exists(tmp_table_name, main_db) | ||
perm_name = 'can_sql_json' | ||
sql_where = ( | ||
"SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)) | ||
result2 = self.run_sql( | ||
db_id, sql_where, '2', tmp_table='tmp_table_2', cta='true') | ||
db_id, sql_where, '2', tmp_table=tmp_table_name, cta='true') | ||
self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) | ||
self.assertEqual([], result2['data']) | ||
self.assertEqual([], result2['columns']) | ||
|
@@ -167,34 +151,42 @@ def test_run_sync_query_cta(self): | |
self.assertEqual([{'name': perm_name}], data2) | ||
|
||
def test_run_sync_query_cta_no_data(self): | ||
db_id = get_main_database(db.session).id | ||
sql_empty_result = 'SELECT * FROM ab_user WHERE id=666 LIMIT 666' | ||
result3 = self.run_sql( | ||
db_id, sql_empty_result, '3', cta='false') | ||
main_db = get_main_database(db.session) | ||
db_id = main_db.id | ||
sql_empty_result = 'SELECT * FROM ab_user WHERE id=666' | ||
result3 = self.run_sql(db_id, sql_empty_result, '3') | ||
self.assertEqual(QueryStatus.SUCCESS, result3['query']['state']) | ||
self.assertEqual([], result3['data']) | ||
self.assertEqual([], result3['columns']) | ||
|
||
query = self.get_query_by_id(result3['query']['serverId']) | ||
self.assertEqual(QueryStatus.SUCCESS, query.status) | ||
self.assertEqual(666, query.limit) | ||
query3 = self.get_query_by_id(result3['query']['serverId']) | ||
self.assertEqual(QueryStatus.SUCCESS, query3.status) | ||
|
||
def drop_table_if_exists(self, table_name, database=None): | ||
"""Drop table if it exists, works on any DB""" | ||
sql = 'DROP TABLE {}'.format(table_name) | ||
db_id = database.id | ||
if database: | ||
database.allow_dml = True | ||
db.session.flush() | ||
return self.run_sql(db_id, sql) | ||
|
||
def test_run_async_query(self): | ||
main_db = get_main_database(db.session) | ||
eng = main_db.get_sqla_engine() | ||
db_id = main_db.id | ||
|
||
self.drop_table_if_exists('tmp_async_1', main_db) | ||
|
||
sql_where = "SELECT name FROM ab_role WHERE name='Admin'" | ||
result = self.run_sql( | ||
main_db.id, sql_where, '4', async_='true', tmp_table='tmp_async_1', | ||
db_id, sql_where, '4', async_='true', tmp_table='tmp_async_1', | ||
cta='true') | ||
assert result['query']['state'] in ( | ||
QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) | ||
|
||
time.sleep(1) | ||
time.sleep(CELERY_SLEEP_TIME) | ||
|
||
query = self.get_query_by_id(result['query']['serverId']) | ||
df = pd.read_sql_query(query.select_sql, con=eng) | ||
self.assertEqual(QueryStatus.SUCCESS, query.status) | ||
self.assertEqual([{'name': 'Admin'}], df.to_dict(orient='records')) | ||
self.assertEqual(QueryStatus.SUCCESS, query.status) | ||
self.assertTrue('FROM tmp_async_1' in query.select_sql) | ||
self.assertEqual( | ||
|
@@ -208,20 +200,19 @@ def test_run_async_query(self): | |
|
||
def test_run_async_query_with_lower_limit(self): | ||
main_db = get_main_database(db.session) | ||
eng = main_db.get_sqla_engine() | ||
db_id = main_db.id | ||
self.drop_table_if_exists('tmp_async_2', main_db) | ||
|
||
sql_where = "SELECT name FROM ab_role WHERE name='Alpha' LIMIT 1" | ||
result = self.run_sql( | ||
main_db.id, sql_where, '5', async_='true', tmp_table='tmp_async_2', | ||
db_id, sql_where, '5', async_='true', tmp_table='tmp_async_2', | ||
cta='true') | ||
assert result['query']['state'] in ( | ||
QueryStatus.PENDING, QueryStatus.RUNNING, QueryStatus.SUCCESS) | ||
|
||
time.sleep(1) | ||
time.sleep(CELERY_SLEEP_TIME) | ||
|
||
query = self.get_query_by_id(result['query']['serverId']) | ||
df = pd.read_sql_query(query.select_sql, con=eng) | ||
self.assertEqual(QueryStatus.SUCCESS, query.status) | ||
self.assertEqual([{'name': 'Alpha'}], df.to_dict(orient='records')) | ||
self.assertEqual(QueryStatus.SUCCESS, query.status) | ||
self.assertTrue('FROM tmp_async_2' in query.select_sql) | ||
self.assertEqual( | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters