From 44eb81e35f769b2f3f3224a3c8e6b2045e48e5cc Mon Sep 17 00:00:00 2001 From: Ville Brofeldt <33317356+villebro@users.noreply.github.com> Date: Thu, 7 Apr 2022 14:04:51 +0300 Subject: [PATCH] fix(sqla): apply jinja to metrics (#19565) (cherry picked from commit 34b55765c4b0cbd8f0b9f89c6ca0f62f4478270e) --- superset/connectors/sqla/models.py | 83 +++++++++++++++++++----------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 8721f6ea8178b..b8d3a7d091423 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -354,6 +354,7 @@ def get_timestamp_expression( :param time_grain: Optional time grain, e.g. P1Y :param label: alias/label that column is expected to have + :param template_processor: template processor :return: A TimeExpression object wrapped in a Label if supported by db """ label = label or utils.DTTM_ALIAS @@ -517,6 +518,27 @@ def data(self) -> Dict[str, Any]: ) +def _process_sql_expression( + expression: Optional[str], + database_id: int, + schema: str, + template_processor: Optional[BaseTemplateProcessor], +) -> Optional[str]: + if template_processor and expression: + expression = template_processor.process_template(expression) + if expression: + expression = validate_adhoc_subquery( + expression, + database_id, + schema, + ) + try: + expression = sanitize_clause(expression) + except QueryClauseValidationException as ex: + raise QueryObjectValidationError(ex.message) from ex + return expression + + class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-methods """An ORM object for SqlAlchemy table references""" @@ -899,13 +921,17 @@ def get_rendered_sql( return sql def adhoc_metric_to_sqla( - self, metric: AdhocMetric, columns_by_name: Dict[str, TableColumn] + self, + metric: AdhocMetric, + columns_by_name: Dict[str, TableColumn], + template_processor: Optional[BaseTemplateProcessor] = None, ) -> ColumnElement: """ Turn an adhoc metric into a sqlalchemy column. :param dict metric: Adhoc metric definition :param dict columns_by_name: Columns for the current table + :param template_processor: template_processor instance :returns: The metric defined as a sqlalchemy column :rtype: sqlalchemy.sql.column """ @@ -922,17 +948,12 @@ def adhoc_metric_to_sqla( sqla_column = column(column_name) sqla_metric = self.sqla_aggregations[metric["aggregate"]](sqla_column) elif expression_type == utils.AdhocMetricExpressionType.SQL: - tp = self.get_template_processor() - expression = tp.process_template(cast(str, metric["sqlExpression"])) - expression = validate_adhoc_subquery( - expression, - self.database_id, - self.schema, + expression = _process_sql_expression( + expression=metric["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, ) - try: - expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex sqla_metric = literal_column(expression) else: raise QueryObjectValidationError("Adhoc metric expressionType is invalid") @@ -953,21 +974,14 @@ def adhoc_column_to_sqla( :rtype: sqlalchemy.sql.column """ label = utils.get_column_name(col) - expression = col["sqlExpression"] - if template_processor and expression: - expression = template_processor.process_template(expression) - if expression: - expression = validate_adhoc_subquery( - expression, - self.database_id, - self.schema, - ) - try: - expression = sanitize_clause(expression) - except QueryClauseValidationException as ex: - raise QueryObjectValidationError(ex.message) from ex - sqla_metric = literal_column(expression) - return self.make_sqla_column_compatible(sqla_metric, label) + expression = _process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, + ) + sqla_column = literal_column(expression) + return self.make_sqla_column_compatible(sqla_column, label) def make_sqla_column_compatible( self, sqla_col: ColumnElement, label: Optional[str] = None @@ -1151,7 +1165,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma for metric in metrics: if utils.is_adhoc_metric(metric): assert isinstance(metric, dict) - metrics_exprs.append(self.adhoc_metric_to_sqla(metric, columns_by_name)) + metrics_exprs.append( + self.adhoc_metric_to_sqla( + metric=metric, + columns_by_name=columns_by_name, + template_processor=template_processor, + ) + ) elif isinstance(metric, str) and metric in metrics_by_name: metrics_exprs.append(metrics_by_name[metric].get_sqla_col()) else: @@ -1178,10 +1198,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma if isinstance(col, dict): col = cast(AdhocMetric, col) if col.get("sqlExpression"): - col["sqlExpression"] = validate_adhoc_subquery( - cast(str, col["sqlExpression"]), - self.database_id, - self.schema, + col["sqlExpression"] = _process_sql_expression( + expression=col["sqlExpression"], + database_id=self.database_id, + schema=self.schema, + template_processor=template_processor, ) if utils.is_adhoc_metric(col): # add adhoc sort by column to columns_by_name if not exists