Skip to content

Commit

Permalink
chore(sqla): refactor query utils (#21811)
Browse files Browse the repository at this point in the history
Co-authored-by: Ville Brofeldt <[email protected]>
  • Loading branch information
2 people authored and michael-s-molina committed Jan 4, 2023
1 parent 4c21c7b commit 7c98e26
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 15 deletions.
26 changes: 19 additions & 7 deletions superset/connectors/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@
from superset.exceptions import (
QueryClauseValidationException,
QueryObjectValidationError,
SupersetSecurityException,
)
from superset.jinja_context import (
BaseTemplateProcessor,
Expand Down Expand Up @@ -514,19 +515,19 @@ def _process_sql_expression(
expression: Optional[str],
database_id: int,
schema: str,
template_processor: Optional[BaseTemplateProcessor],
template_processor: Optional[BaseTemplateProcessor] = None,
) -> Optional[str]:
if template_processor and expression:
expression = template_processor.process_template(expression)
if expression:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
try:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
expression = sanitize_clause(expression)
except QueryClauseValidationException as ex:
except (QueryClauseValidationException, SupersetSecurityException) as ex:
raise QueryObjectValidationError(ex.message) from ex
return expression

Expand Down Expand Up @@ -1465,6 +1466,11 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
where = _process_sql_expression(
expression=where,
database_id=self.database_id,
schema=self.schema,
)
where_clause_and += [self.text(where)]
having = extras.get("having")
if having:
Expand All @@ -1477,7 +1483,13 @@ def get_sqla_query( # pylint: disable=too-many-arguments,too-many-locals,too-ma
msg=ex.message,
)
) from ex
having = _process_sql_expression(
expression=having,
database_id=self.database_id,
schema=self.schema,
)
having_clause_and += [self.text(having)]

if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate())
if granularity:
Expand Down
7 changes: 7 additions & 0 deletions superset/connectors/sqla/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from flask_babel import lazy_gettext as _
from sqlalchemy import and_, inspect, or_
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL as SqlaURL
from sqlalchemy.exc import NoSuchTableError
from sqlalchemy.orm import Session
from sqlalchemy.sql.type_api import TypeEngine
Expand All @@ -37,6 +38,7 @@
from superset.sql_parse import has_table_query, insert_rls, ParsedQuery, Table
from superset.superset_typing import ResultSetColumnType
from superset.tables.models import Table as NewTable
from superset.utils.memoized import memoized

if TYPE_CHECKING:
from superset.connectors.sqla.models import SqlaTable
Expand Down Expand Up @@ -252,3 +254,8 @@ def load_or_create_tables( # pylint: disable=too-many-arguments
existing.add((table.schema, table.table))

return new_tables


@memoized
def get_identifier_quoter(drivername: str) -> Callable[[str], str]:
    """Return the identifier-quoting function of the dialect for ``drivername``.

    A SQLAlchemy URL is built from the driver name alone, the dialect class it
    resolves to is instantiated, and the ``quote`` method of that dialect's
    identifier preparer is returned. The function is wrapped in ``memoized``.

    :param drivername: SQLAlchemy driver name used to look up the dialect
    :returns: a callable that takes an identifier and returns it quoted
    """
    dialect = SqlaURL(drivername=drivername).get_dialect()()
    return dialect.identifier_preparer.quote
11 changes: 10 additions & 1 deletion superset/models/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import textwrap
from ast import literal_eval
from contextlib import closing
from contextlib import closing, contextmanager
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
Expand Down Expand Up @@ -345,6 +345,15 @@ def get_effective_user(
effective_username = g.user.username
return effective_username

@contextmanager
def get_sqla_engine_with_context(
    self,
    schema: Optional[str] = None,
    nullpool: bool = True,
    source: Optional[utils.QuerySource] = None,
) -> Engine:
    """Context-manager wrapper around ``get_sqla_engine``.

    All arguments are forwarded unchanged to ``get_sqla_engine`` and the
    resulting engine is yielded to the ``with`` block. Nothing is done with
    the engine when the block exits.
    """
    # NOTE(review): the return annotation names the yielded value; the
    # undecorated function is a generator.
    engine = self.get_sqla_engine(schema=schema, nullpool=nullpool, source=source)
    yield engine

@memoized(
watch=(
"impersonate_user",
Expand Down
71 changes: 70 additions & 1 deletion tests/integration_tests/charts/data/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import copy
from datetime import datetime
from io import BytesIO
from typing import Optional
from typing import Optional, Dict, Any
from unittest import mock
from zipfile import ZipFile

Expand Down Expand Up @@ -974,3 +974,72 @@ def test_chart_data_with_adhoc_column(self):
unique_genders = {row["male_or_female"] for row in data}
assert unique_genders == {"male", "female"}
assert result["applied_filters"] == [{"column": "male_or_female"}]


@pytest.fixture()
def physical_query_context(physical_dataset) -> Dict[str, Any]:
    """Build a chart-data request payload that targets ``physical_dataset``.

    The payload holds a single query selecting ``col1`` with the ``count``
    metric, ordered by ``col1``, and asks for a forced, full result.
    """
    datasource = {
        "type": physical_dataset.type,
        "id": physical_dataset.id,
    }
    query = {
        "columns": ["col1"],
        "metrics": ["count"],
        "orderby": [["col1", True]],
    }
    return {
        "datasource": datasource,
        "queries": [query],
        "result_type": ChartDataResultType.FULL,
        "force": True,
    }


@pytest.mark.parametrize(
    "status_code,extras",
    [
        (200, {"where": "1 = 1"}),
        (200, {"having": "count(*) > 0"}),
        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
    ],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_not_allowed(
    test_client,
    login_as_admin,
    physical_dataset,
    physical_query_context,
    status_code,
    extras,
):
    """Check chart-data status codes with ``ALLOW_ADHOC_SUBQUERY`` disabled.

    Plain ``where``/``having`` extras are expected to return 200, while
    extras containing a subquery are expected to return 400.
    """
    first_query = physical_query_context["queries"][0]
    first_query["extras"] = extras

    response = test_client.post(CHART_DATA_URI, json=physical_query_context)

    assert response.status_code == status_code


@pytest.mark.parametrize(
    "status_code,extras",
    [
        (200, {"where": "1 = 1"}),
        (200, {"having": "count(*) > 0"}),
        (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
        (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
    ],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_allowed(
    test_client,
    login_as_admin,
    physical_dataset,
    physical_query_context,
    status_code,
    extras,
):
    """Check chart-data status codes with ``ALLOW_ADHOC_SUBQUERY`` enabled.

    Every parametrized ``where``/``having`` extra, with or without a
    subquery, is expected to return 200.
    """
    first_query = physical_query_context["queries"][0]
    first_query["extras"] = extras

    response = test_client.post(CHART_DATA_URI, json=physical_query_context)

    assert response.status_code == status_code
99 changes: 98 additions & 1 deletion tests/integration_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,15 @@
from unittest.mock import patch

import pytest
from flask.ctx import AppContext
from flask.testing import FlaskClient
from sqlalchemy.engine import Engine

from superset import db
from superset.extensions import feature_flag_manager
from superset.utils.core import json_dumps_w_dates
from superset.utils.database import get_example_database, remove_database
from tests.integration_tests.test_app import app
from tests.integration_tests.test_app import app, login

if TYPE_CHECKING:
from superset.connectors.sqla.models import Database
Expand All @@ -42,6 +44,29 @@ def app_context():
yield


@pytest.fixture
def test_client(app_context: AppContext):
    """Yield a Flask test client created from the shared test ``app``.

    The client is opened as a context manager around the ``yield``, so it is
    closed once the requesting test has finished.
    """
    with app.test_client() as flask_client:
        yield flask_client


@pytest.fixture
def login_as(test_client: "FlaskClient[Any]"):
    """Yield a callable that logs a user in through ``test_client``.

    The yielded function takes a ``username`` and an optional ``password``
    (default ``"general"``) and forwards both to ``login`` together with the
    test client. No user is logged in until the callable is invoked.
    """

    def _login_as(username: str, password: str = "general") -> None:
        login(test_client, username=username, password=password)

    yield _login_as
    # no need to log out as both app_context and test_client are
    # function level fixtures anyway


@pytest.fixture
def login_as_admin(login_as: Callable[..., None]):
    """Log in as the ``admin`` user via ``login_as`` before the test runs.

    The value yielded to the test is whatever the ``login_as`` callable
    returns for ``"admin"``.
    """
    login_result = login_as("admin")
    yield login_result


@pytest.fixture(autouse=True, scope="session")
def setup_sample_data() -> Any:
# TODO(john-bodley): Determine a cleaner way of setting up the sample data without
Expand Down Expand Up @@ -180,3 +205,75 @@ def wrapper(*args, **kwargs):
return functools.update_wrapper(wrapper, test_fn)

return decorate


@pytest.fixture
def physical_dataset():
    """Provide a ``physical_dataset`` table plus a matching ``SqlaTable``.

    Setup creates the table in the example database (if it does not exist
    yet), inserts ten rows, and registers a dataset with seven columns and a
    ``count`` metric. After the test, the table is dropped and every
    ``SqlaTable`` named ``physical_dataset`` is deleted from the session.
    """
    from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
    from superset.connectors.sqla.utils import get_identifier_quoter

    example_database = get_example_database()

    with example_database.get_sqla_engine_with_context() as engine:
        quoter = get_identifier_quoter(engine.name)
        create_table_sql = f"""
            CREATE TABLE IF NOT EXISTS physical_dataset(
              col1 INTEGER,
              col2 VARCHAR(255),
              col3 DECIMAL(4,2),
              col4 VARCHAR(255),
              col5 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
              col6 TIMESTAMP DEFAULT '1970-01-01 00:00:01',
              {quoter('time column with spaces')} TIMESTAMP DEFAULT '1970-01-01 00:00:01'
            );
            """
        insert_rows_sql = """
            INSERT INTO physical_dataset values
            (0, 'a', 1.0, NULL, '2000-01-01 00:00:00', '2002-01-03 00:00:00', '2002-01-03 00:00:00'),
            (1, 'b', 1.1, NULL, '2000-01-02 00:00:00', '2002-02-04 00:00:00', '2002-02-04 00:00:00'),
            (2, 'c', 1.2, NULL, '2000-01-03 00:00:00', '2002-03-07 00:00:00', '2002-03-07 00:00:00'),
            (3, 'd', 1.3, NULL, '2000-01-04 00:00:00', '2002-04-12 00:00:00', '2002-04-12 00:00:00'),
            (4, 'e', 1.4, NULL, '2000-01-05 00:00:00', '2002-05-11 00:00:00', '2002-05-11 00:00:00'),
            (5, 'f', 1.5, NULL, '2000-01-06 00:00:00', '2002-06-13 00:00:00', '2002-06-13 00:00:00'),
            (6, 'g', 1.6, NULL, '2000-01-07 00:00:00', '2002-07-15 00:00:00', '2002-07-15 00:00:00'),
            (7, 'h', 1.7, NULL, '2000-01-08 00:00:00', '2002-08-18 00:00:00', '2002-08-18 00:00:00'),
            (8, 'i', 1.8, NULL, '2000-01-09 00:00:00', '2002-09-20 00:00:00', '2002-09-20 00:00:00'),
            (9, 'j', 1.9, NULL, '2000-01-10 00:00:00', '2002-10-22 00:00:00', '2002-10-22 00:00:00');
        """
        # sqlite can only execute one statement at a time
        for statement in (create_table_sql, insert_rows_sql):
            engine.execute(statement)

    dataset = SqlaTable(
        table_name="physical_dataset",
        database=example_database,
    )

    # Columns are attached to the dataset in the order they appear in the table.
    column_specs = [
        {"column_name": "col1", "type": "INTEGER"},
        {"column_name": "col2", "type": "VARCHAR(255)"},
        {"column_name": "col3", "type": "DECIMAL(4,2)"},
        {"column_name": "col4", "type": "VARCHAR(255)"},
        {"column_name": "col5", "type": "TIMESTAMP", "is_dttm": True},
        {"column_name": "col6", "type": "TIMESTAMP", "is_dttm": True},
        {
            "column_name": "time column with spaces",
            "type": "TIMESTAMP",
            "is_dttm": True,
        },
    ]
    for spec in column_specs:
        TableColumn(table=dataset, **spec)
    SqlMetric(metric_name="count", expression="count(*)", table=dataset)

    db.session.merge(dataset)
    db.session.commit()

    yield dataset

    # Teardown: remove the physical table, then every dataset row pointing at it.
    engine.execute(
        """
        DROP TABLE physical_dataset;
    """
    )
    leftover_datasets = (
        db.session.query(SqlaTable).filter_by(table_name="physical_dataset").all()
    )
    for leftover in leftover_datasets:
        db.session.delete(leftover)
    db.session.commit()
2 changes: 1 addition & 1 deletion tests/integration_tests/sqla_models_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def test_adhoc_metrics_and_calc_columns(self):
)
db.session.commit()

with pytest.raises(SupersetSecurityException):
with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**base_query_obj)
# Cleanup
db.session.delete(table)
Expand Down
20 changes: 16 additions & 4 deletions tests/integration_tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,23 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import TYPE_CHECKING

"""
Here is where we create the app which ends up being shared across all tests.integration_tests. A future
optimization will be to create a separate app instance for each test class.
"""
from superset.app import create_app

if TYPE_CHECKING:
from typing import Any

from flask.testing import FlaskClient

app = create_app()


def login(
    client: "FlaskClient[Any]", username: str = "admin", password: str = "general"
):
    """Post the given credentials to ``/login/`` using ``client``.

    The response body is read as text and the helper asserts that it does not
    contain ``"User confirmation needed"``.
    """
    credentials = {"username": username, "password": password}
    response = client.post("/login/", data=credentials)
    body = response.get_data(as_text=True)
    assert "User confirmation needed" not in body

0 comments on commit 7c98e26

Please sign in to comment.