diff --git a/superset/sql_lab.py b/superset/sql_lab.py index 86e171ba44d77..32ad61f6927a9 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -269,7 +269,6 @@ def execute_sql_statements( query.rows = cdf.size query.progress = 100 query.set_extra_json_key('progress', None) - query.status = QueryStatus.SUCCESS if query.select_as_cta: query.select_sql = database.select_star( query.tmp_table_name, @@ -285,13 +284,14 @@ def execute_sql_statements( selected_columns, data) payload.update({ - 'status': query.status, + 'status': QueryStatus.SUCCESS, 'data': data, 'columns': all_columns, 'selected_columns': selected_columns, 'expanded_columns': expanded_columns, 'query': query.to_dict(), }) + payload['query']['state'] = QueryStatus.SUCCESS if store_results: key = str(uuid.uuid4()) @@ -304,6 +304,8 @@ def execute_sql_statements( cache_timeout = config.get('CACHE_DEFAULT_TIMEOUT', 0) results_backend.set(key, zlib_compress(json_payload), cache_timeout) query.results_key = key + + query.status = QueryStatus.SUCCESS session.commit() if return_results: diff --git a/tests/celery_tests.py b/tests/celery_tests.py index bccffb21144cc..320697de1fa8f 100644 --- a/tests/celery_tests.py +++ b/tests/celery_tests.py @@ -127,7 +127,7 @@ def run_sql(self, db_id, sql, client_id=None, cta='false', tmp_table='tmp', ), ) self.logout() - return json.loads(resp.data.decode('utf-8')) + return json.loads(resp.data) def test_run_sync_query_dont_exist(self): main_db = get_main_database(db.session) @@ -145,12 +145,12 @@ def test_run_sync_query_cta(self): perm_name = 'can_sql_json' sql_where = ( "SELECT name FROM ab_permission WHERE name='{}'".format(perm_name)) - result2 = self.run_sql( + result = self.run_sql( db_id, sql_where, '2', tmp_table=tmp_table_name, cta='true') - self.assertEqual(QueryStatus.SUCCESS, result2['query']['state']) - self.assertEqual([], result2['data']) - self.assertEqual([], result2['columns']) - query2 = self.get_query_by_id(result2['query']['serverId']) + self.assertEqual(QueryStatus.SUCCESS, result['query']['state']) + self.assertEqual([], result['data']) + self.assertEqual([], result['columns']) + query2 = self.get_query_by_id(result['query']['serverId']) # Check the data in the tmp table. if backend != 'postgresql':