Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(sqla): apply jinja to metrics #19565

Merged
merged 1 commit into from
Apr 7, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 52 additions & 31 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def get_timestamp_expression(

:param time_grain: Optional time grain, e.g. P1Y
:param label: alias/label that column is expected to have
:param template_processor: template processor
:return: A TimeExpression object wrapped in a Label if supported by db
"""
label = label or utils.DTTM_ALIAS
Expand Down Expand Up @@ -488,6 +489,27 @@ def data(self) -> Dict[str, Any]:
)


def _process_sql_expression(
expression: Optional[str],
database_id: int,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression


class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods
"""An ORM object for SqlAlchemy table references"""

Expand Down Expand Up @@ -875,13 +897,17 @@ def get_rendered_sql(
return sql

def adhoc_metric_to_sqla(
self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn]
self,
metric: AdhocMetric,
columns_by_name: Dict[str, TableColumn],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> ColumnElement:
"""
Turn an adhoc metric into a sqlalchemy column.

:param dict metric: Adhoc metric definition
:param dict columns_by_name: Columns for the current table
:param template_processor: template_processor instance
:returns: The metric defined as a sqlalchemy column
:rtype: sqlalchemy.sql.column
"""
Expand All @@ -898,17 +924,12 @@ def adhoc_metric_to_sqla(
sqla_column = column(column_name)
sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column)
elif expression_type == utils.AdhocMetricExpressionType.SQL:
tp = self.get_template_processor()
expression = tp.process_template(cast(str, metric["sqlExpression"]))
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
expression = _process_sql_expression(
expression=metric["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
else:
raise QueryObjectValidationError("Adhoc metric expressionType is invalid")
Expand All @@ -929,21 +950,14 @@ def adhoc_column_to_sqla(
:rtype: sqlalchemy.sql.column
"""
label = utils.get_column_name(col)
expression = col["sqlExpression"]
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
self.database_id,
self.schema,
)
try:
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
raise QueryObjectValidationError(ex.message) from ex
sqla_metric = literal_column(expression)
return self.make_sqla_column_compatible(sqla_metric, label)
expression = _process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
)
sqla_column = literal_column(expression)
return self.make_sqla_column_compatible(sqla_column, label)

def make_sqla_column_compatible(
self, sqla_col: ColumnElement, label: Optional[str] = None
Expand Down Expand Up @@ -1127,7 +1141,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
for metric in metrics:
if utils.is_adhoc_metric(metric):
assert isinstance(metric, dict)
metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name))
metrics_exprs.append(
self.adhoc_metric_to_sqla(
metric=metric,
columns_by_name=columns_by_name,
template_processor=template_processor,
)
)
elif isinstance(metric, str) and metric in metrics_by_name:
metrics_exprs.append(metrics_by_name[metric].get_sqla_col())
else:
Expand All @@ -1154,10 +1174,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
if isinstance(col, dict):
col = cast(AdhocMetric, col)
if col.get("sqlExpression"):
col["sqlExpression"] = validate_adhoc_subquery(
cast(str, col["sqlExpression"]),
self.database_id,
self.schema,
col["sqlExpression"] = _process_sql_expression(
expression=col["sqlExpression"],
database_id=self.database_id,
schema=self.schema,
template_processor=template_processor,
)
if utils.is_adhoc_metric(col):
# add adhoc sort by column to columns_by_name if not exists
Expand Down