Skip to content

Commit

Permalink
Escaping the user's SQL in the explore view (#3186)
Browse files Browse the repository at this point in the history
* Escaping the user's SQL in the explore view

When executing SQL from SQL Lab, we use a lower level API to the
database which doesn't require escaping the SQL. When going through
the explore view, the stack chain leading to the same method may need
escaping depending on how the DBAPI driver is written, and that is the
case for Presto (and perhaps other drivers).

* Using regex to avoid doubling doubles
  • Loading branch information
mistercrunch authored Jul 27, 2017
1 parent fb866a9 commit 25c599d
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
16 changes: 10 additions & 6 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,10 +285,12 @@ def values_for_column(self, column_name, limit=10000):
"""
cols = {col.column_name: col for col in self.columns}
target_col = cols[column_name]
tp = self.get_template_processor()
db_engine_spec = self.database.db_engine_spec

qry = (
select([target_col.sqla_col])
.select_from(self.get_from_clause())
.select_from(self.get_from_clause(tp, db_engine_spec))
.distinct(column_name)
)
if limit:
Expand Down Expand Up @@ -322,7 +324,6 @@ def get_query_str(self, query_obj):
)
logging.info(sql)
sql = sqlparse.format(sql, reindent=True)
sql = self.database.db_engine_spec.sql_preprocessor(sql)
return sql

def get_sqla_table(self):
Expand All @@ -331,12 +332,14 @@ def get_sqla_table(self):
tbl.schema = self.schema
return tbl

def get_from_clause(self, template_processor=None):
def get_from_clause(self, template_processor=None, db_engine_spec=None):
# Supporting arbitrary SQL statements in place of tables
if self.sql:
from_sql = self.sql
if template_processor:
from_sql = template_processor.process_template(from_sql)
if db_engine_spec:
from_sql = db_engine_spec.escape_sql(from_sql)
return TextAsFrom(sa.text(from_sql), []).alias('expr_qry')
return self.get_sqla_table()

Expand Down Expand Up @@ -367,13 +370,14 @@ def get_sqla_query( # sqla
'form_data': form_data,
}
template_processor = self.get_template_processor(**template_kwargs)
db_engine_spec = self.database.db_engine_spec

# For backward compatibility
if granularity not in self.dttm_cols:
granularity = self.main_dttm_col

# Database spec supports join-free timeslot grouping
time_groupby_inline = self.database.db_engine_spec.time_groupby_inline
time_groupby_inline = db_engine_spec.time_groupby_inline

cols = {col.column_name: col for col in self.columns}
metrics_dict = {m.metric_name: m for m in self.metrics}
Expand Down Expand Up @@ -428,7 +432,7 @@ def get_sqla_query( # sqla
groupby_exprs += [timestamp]

# Use main dttm column to support index with secondary dttm columns
if self.database.db_engine_spec.time_secondary_columns and \
if db_engine_spec.time_secondary_columns and \
self.main_dttm_col in self.dttm_cols and \
self.main_dttm_col != dttm_col.column_name:
time_filters.append(cols[self.main_dttm_col].
Expand All @@ -438,7 +442,7 @@ def get_sqla_query( # sqla
select_exprs += metrics_exprs
qry = sa.select(select_exprs)

tbl = self.get_from_clause(template_processor)
tbl = self.get_from_clause(template_processor, db_engine_spec)

if not columns:
qry = qry.group_by(*groupby_exprs)
Expand Down
17 changes: 9 additions & 8 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def extra_table_metadata(cls, database, table_name, schema_name):
"""Returns engine-specific table metadata"""
return {}

@classmethod
def escape_sql(cls, sql):
"""Escapes the raw SQL"""
return sql

@classmethod
def convert_dttm(cls, target_type, dttm):
return "'{}'".format(dttm.strftime('%Y-%m-%d %H:%M:%S'))
Expand Down Expand Up @@ -139,14 +144,6 @@ def adjust_database_uri(cls, uri, selected_schema):
"""
return uri

@classmethod
def sql_preprocessor(cls, sql):
"""If the SQL needs to be altered prior to running it
For example Presto needs to double `%` characters
"""
return sql

@classmethod
def patch(cls):
pass
Expand Down Expand Up @@ -399,6 +396,10 @@ def adjust_database_uri(cls, uri, selected_schema=None):
uri.database = database
return uri

@classmethod
def escape_sql(cls, sql):
return re.sub(r'%%|%', "%%", sql)

@classmethod
def convert_dttm(cls, target_type, dttm):
tt = target_type.upper()
Expand Down
1 change: 0 additions & 1 deletion superset/sql_lab.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,6 @@ def handle_error(msg):
template_processor = get_template_processor(
database=database, query=query)
executed_sql = template_processor.process_template(executed_sql)
executed_sql = db_engine_spec.sql_preprocessor(executed_sql)
except Exception as e:
logging.exception(e)
msg = "Template rendering failed: " + utils.error_msg_from_exception(e)
Expand Down

0 comments on commit 25c599d

Please sign in to comment.