From 93b270c7671ff9e5cc3bb7614954a9e308be02a9 Mon Sep 17 00:00:00 2001
From: Jesse Yang
Date: Mon, 12 Oct 2020 00:18:31 -0700
Subject: [PATCH 1/4] perf: cache dashboard bootstrap data

---
 superset/charts/commands/delete.py    |   2 +
 superset/config.py                    |   2 +
 superset/connectors/druid/__init__.py |   1 -
 superset/dashboards/dao.py            |   2 +-
 superset/models/core.py               |   6 +-
 superset/models/dashboard.py          | 139 +++++++++++++++++++++----
 superset/utils/core.py                |  11 ++
 superset/utils/decorators.py          |  39 ++++++-
 superset/views/core.py                | 142 ++++++++++++--------------
 superset/views/dashboard/views.py     |   8 +-
 tests/superset_test_config.py         |   1 +
 11 files changed, 243 insertions(+), 110 deletions(-)

diff --git a/superset/charts/commands/delete.py b/superset/charts/commands/delete.py
index 3feb3dbc09ff5..c6392a9055cd2 100644
--- a/superset/charts/commands/delete.py
+++ b/superset/charts/commands/delete.py
@@ -29,6 +29,7 @@
 from superset.commands.base import BaseCommand
 from superset.dao.exceptions import DAODeleteFailedError
 from superset.exceptions import SupersetSecurityException
+from superset.models.dashboard import Dashboard
 from superset.models.slice import Slice
 from superset.views.base import check_ownership
 
@@ -44,6 +45,7 @@ def __init__(self, user: User, model_id: int):
     def run(self) -> Model:
         self.validate()
         try:
+            Dashboard.clear_cache_for_slice(slice_id=self._model_id)
             chart = ChartDAO.delete(self._model)
         except DAODeleteFailedError as ex:
             logger.exception(ex.exception)
diff --git a/superset/config.py b/superset/config.py
index dce21360ea29f..9fbcbebce7d93 100644
--- a/superset/config.py
+++ b/superset/config.py
@@ -302,6 +302,7 @@ def _try_json_readsha(  # pylint: disable=unused-argument
     "PRESTO_EXPAND_DATA": False,
     # Exposes API endpoint to compute thumbnails
     "THUMBNAILS": False,
+    "DASHBOARD_CACHE": False,
     "REMOVE_SLICE_LEVEL_LABEL_COLORS": False,
     "SHARE_QUERIES_VIA_KV_STORE": False,
     "SIP_38_VIZ_REARCHITECTURE": False,
@@ -368,6 +369,7 @@ def _try_json_readsha(  # pylint: disable=unused-argument
 CACHE_DEFAULT_TIMEOUT = 60 * 60 * 24
 CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
 TABLE_NAMES_CACHE_CONFIG: CacheConfig = {"CACHE_TYPE": "null"}
+DASHBOARD_CACHE_TIMEOUT = 60 * 60 * 24 * 365
 
 # CORS Options
 ENABLE_CORS = False
diff --git a/superset/connectors/druid/__init__.py b/superset/connectors/druid/__init__.py
index ad52fc6d8bcb8..13a83393a9124 100644
--- a/superset/connectors/druid/__init__.py
+++ b/superset/connectors/druid/__init__.py
@@ -14,4 +14,3 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from . import models, views
diff --git a/superset/dashboards/dao.py b/superset/dashboards/dao.py
index 46ce63a4f7b9a..35db3318f35d3 100644
--- a/superset/dashboards/dao.py
+++ b/superset/dashboards/dao.py
@@ -108,7 +108,7 @@ def set_dash_metadata(
                     and obj["meta"]["chartId"]
                 ):
                     chart_id = obj["meta"]["chartId"]
-                    obj["meta"]["uuid"] = uuid_map[chart_id]
+                    obj["meta"]["uuid"] = uuid_map.get(chart_id)
 
         # remove leading and trailing white spaces in the dumped json
         dashboard.position_json = json.dumps(
diff --git a/superset/models/core.py b/superset/models/core.py
index ed6f176d26e5f..36bed2a1e939f 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -54,9 +54,8 @@
 
 from superset import app, db_engine_specs, is_feature_enabled, security_manager
 from superset.db_engine_specs.base import TimeGrain
-from superset.models.dashboard import Dashboard
 from superset.models.helpers import AuditMixinNullable, ImportMixin
-from superset.models.tags import DashboardUpdater, FavStarUpdater
+from superset.models.tags import FavStarUpdater
 from superset.result_set import SupersetResultSet
 from superset.utils import cache as cache_util, core as utils
 
@@ -719,8 +718,5 @@ class FavStar(Model):  # pylint: disable=too-few-public-methods
 
 # events for updating tags
 if is_feature_enabled("TAGGING_SYSTEM"):
-    sqla.event.listen(Dashboard, "after_insert", DashboardUpdater.after_insert)
-    sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update)
-    sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete)
     sqla.event.listen(FavStar, "after_insert", FavStarUpdater.after_insert)
     sqla.event.listen(FavStar, "after_delete", FavStarUpdater.after_delete)
diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 56ca1ba6ac656..003fc2fa282c2 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -17,8 +17,9 @@
 import json
 import logging
 from copy import copy
+from functools import partial
 from json.decoder import JSONDecodeError
-from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING
+from typing import Any, Callable, Dict, List, Optional, Set, Union
 from urllib import parse
 
 import sqlalchemy as sqla
@@ -40,8 +41,20 @@
 from sqlalchemy.engine.base import Connection
 from sqlalchemy.orm import relationship, sessionmaker, subqueryload
 from sqlalchemy.orm.mapper import Mapper
-
-from superset import app, ConnectorRegistry, db, is_feature_enabled, security_manager
+from sqlalchemy.orm.session import object_session
+from sqlalchemy.sql import join, select
+
+from superset import (
+    app,
+    cache,
+    ConnectorRegistry,
+    db,
+    is_feature_enabled,
+    security_manager,
+)
+from superset.connectors.base.models import BaseDatasource
+from superset.connectors.druid.models import DruidColumn, DruidMetric
+from superset.connectors.sqla.models import SqlMetric, TableColumn
 from superset.models.helpers import AuditMixinNullable, ImportMixin
 from superset.models.slice import Slice
 from superset.models.tags import DashboardUpdater
@@ -52,11 +65,9 @@
     convert_filter_scopes,
     copy_filter_scopes,
 )
+from superset.utils.decorators import debounce
 from superset.utils.urls import get_url_path
 
-if TYPE_CHECKING:
-    from superset.connectors.base.models import BaseDatasource
-
 metadata = Model.metadata  # pylint: disable=no-member
 config = app.config
 logger = logging.getLogger(__name__)
@@ -131,7 +142,7 @@ class Dashboard(  # pylint: disable=too-many-instance-attributes
     css = Column(Text)
     json_metadata = Column(Text)
     slug = Column(String(255), unique=True)
-    slices = relationship("Slice", secondary=dashboard_slices, backref="dashboards")
relationship("Slice", secondary=dashboard_slices, backref="dashboards") + slices = relationship(Slice, secondary=dashboard_slices, backref="dashboards") owners = relationship(security_manager.user_model, secondary=dashboard_user) published = Column(Boolean, default=False) @@ -145,7 +156,7 @@ class Dashboard( # pylint: disable=too-many-instance-attributes ] def __repr__(self) -> str: - return self.dashboard_title or str(self.id) + return f"Dashboard<{self.slug or self.id}>" @property def table_names(self) -> str: @@ -177,11 +188,11 @@ def url(self) -> str: return url @property - def datasources(self) -> Set[Optional["BaseDatasource"]]: + def datasources(self) -> Set[BaseDatasource]: return {slc.datasource for slc in self.slices} @property - def charts(self) -> List[Optional["BaseDatasource"]]: + def charts(self) -> List[BaseDatasource]: return [slc.chart for slc in self.slices] @property @@ -240,6 +251,29 @@ def data(self) -> Dict[str, Any]: "last_modified_time": self.changed_on.replace(microsecond=0).timestamp(), } + @cache.memoize( + # manually maintain cache key version + make_name=lambda fname: f"{fname}-v1", + timeout=config["DASHBOARD_CACHE_TIMEOUT"], + unless=lambda: not is_feature_enabled("DASHBOARD_CACHE"), + ) + def full_data(self) -> Dict[str, Any]: + """Bootstrap data for rendering the dashboard page.""" + slices = self.slices + datasource_slices = utils.indexed(slices, "datasource") + return { + # dashboard metadata + "dashboard": self.data, + # slices data + "slices": [slc.data for slc in slices], + # datasource data + "datasources": { + # Filter out unneeded fields from the datasource payload + datasource.uid: datasource.data_for_slices(slices) + for datasource, slices in datasource_slices.items() + }, + } + @property # type: ignore def params(self) -> str: # type: ignore return self.json_metadata @@ -254,6 +288,39 @@ def position(self) -> Dict[str, Any]: return json.loads(self.position_json) return {} + def update_thumbnail(self) -> None: + url = get_url_path("Superset.dashboard", dashboard_id_or_slug=self.id) + cache_dashboard_thumbnail.delay(url, self.digest, force=True) + + @debounce(0.1) + def clear_cache(self) -> None: + cache.delete_memoized(self.full_data) + + @classmethod + @debounce(0.1) + def clear_cache_for_slice(cls, slice_id: int) -> None: + filter_query = select([dashboard_slices.c.dashboard_id], distinct=True).where( + dashboard_slices.c.slice_id == slice_id + ) + for (dashboard_id,) in db.session.execute(filter_query): + cls(id=dashboard_id).clear_cache() + + @classmethod + @debounce(0.1) + def clear_cache_for_datasource(cls, datasource_id: int) -> None: + filter_query = select( + [dashboard_slices.c.dashboard_id], distinct=True, + ).select_from( + join( + Slice, + dashboard_slices, + Slice.id == dashboard_slices.c.slice_id, + Slice.datasource_id == datasource_id, + ) + ) + for (dashboard_id,) in db.session.execute(filter_query): + cls(id=dashboard_id).clear_cache() + @classmethod def import_obj( # pylint: disable=too-many-locals,too-many-branches,too-many-statements @@ -489,12 +556,7 @@ def export_dashboards( # pylint: disable=too-many-locals ) -def event_after_dashboard_changed( - _mapper: Mapper, _connection: Connection, target: Dashboard -) -> None: - url = get_url_path("Superset.dashboard", dashboard_id_or_slug=target.id) - cache_dashboard_thumbnail.delay(url, target.digest, force=True) - +OnDashboardChange = Callable[[Mapper, Connection, Dashboard], Any] # events for updating tags if is_feature_enabled("TAGGING_SYSTEM"): @@ -502,8 +564,45 @@ def 
event_after_dashboard_changed( sqla.event.listen(Dashboard, "after_update", DashboardUpdater.after_update) sqla.event.listen(Dashboard, "after_delete", DashboardUpdater.after_delete) - -# events for updating tags if is_feature_enabled("THUMBNAILS_SQLA_LISTENERS"): - sqla.event.listen(Dashboard, "after_insert", event_after_dashboard_changed) - sqla.event.listen(Dashboard, "after_update", event_after_dashboard_changed) + update_thumbnail: OnDashboardChange = lambda _, __, dash: dash.update_thumbnail() + sqla.event.listen(Dashboard, "after_insert", update_thumbnail) + sqla.event.listen(Dashboard, "after_update", update_thumbnail) + +if is_feature_enabled("DASHBOARD_CACHE"): + + def clear_dashboard_cache( + _mapper: Mapper, + _connection: Connection, + obj: Union[Slice, BaseDatasource, Dashboard], + check_modified: bool = True, + ) -> None: + if check_modified and not object_session(obj).is_modified(obj): + # needed for avoiding excessive cache purging when duplicating a dashboard + return + if isinstance(obj, Dashboard): + obj.clear_cache() + elif isinstance(obj, Slice): + Dashboard.clear_cache_for_slice(slice_id=obj.id) + elif isinstance(obj, BaseDatasource): + Dashboard.clear_cache_for_datasource(datasource_id=obj.id) + elif isinstance(obj, (SqlMetric, TableColumn)): + Dashboard.clear_cache_for_datasource(datasource_id=obj.table_id) + elif isinstance(obj, (DruidMetric, DruidColumn)): + Dashboard.clear_cache_for_datasource(datasource_id=obj.datasource_id) + + sqla.event.listen(Dashboard, "after_update", clear_dashboard_cache) + sqla.event.listen( + Dashboard, "after_delete", partial(clear_dashboard_cache, check_modified=False) + ) + sqla.event.listen(Slice, "after_update", clear_dashboard_cache) + sqla.event.listen(Slice, "after_delete", clear_dashboard_cache) + sqla.event.listen( + BaseDatasource, "after_update", clear_dashboard_cache, propagage=True + ) + # also clear cache on column/metric updates since updates to these will not + # trigger update events for BaseDatasource. + sqla.event.listen(SqlMetric, "after_update", clear_dashboard_cache) + sqla.event.listen(TableColumn, "after_update", clear_dashboard_cache) + sqla.event.listen(DruidMetric, "after_update", clear_dashboard_cache) + sqla.event.listen(DruidColumn, "after_update", clear_dashboard_cache) diff --git a/superset/utils/core.py b/superset/utils/core.py index 541d24b411d07..595f4cf01a93d 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -1418,6 +1418,17 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]: return columns +def indexed( + items: List[Any], key: Union[str, Callable[[Any], Any]] +) -> Dict[Any, List[Any]]: + """Build an index for a list of objects""" + idx: Dict[Any, Any] = {} + for item in items: + key_ = getattr(item, key) if isinstance(key, str) else key(item) + idx.setdefault(key_, []).append(item) + return idx + + class LenientEnum(Enum): """Enums that do not raise ValueError when value is invalid""" diff --git a/superset/utils/decorators.py b/superset/utils/decorators.py index 694e07bd2c434..8e5e9acd13920 100644 --- a/superset/utils/decorators.py +++ b/superset/utils/decorators.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. 
 import logging
+import time
 from datetime import datetime, timedelta
 from functools import wraps
-from typing import Any, Callable, Iterator
+from typing import Any, Callable, Dict, Iterator, Union
 
 from contextlib2 import contextmanager
 from flask import request
@@ -123,3 +124,39 @@ def wrapper(*args: Any, **kwargs: Any) -> ETagResponseMixin:
             return wrapper
 
     return decorator
+
+
+def arghash(args: Any, kwargs: Dict[str, Any]) -> int:
+    """Simple argument hash with kwargs sorted."""
+    sorted_args = tuple(
+        repr(x) if hasattr(x, "__repr__") else x for x in [*args, *sorted(kwargs.items())]
+    )
+    return hash(sorted_args)
+
+
+def debounce(duration: Union[float, int] = 0.1) -> Callable[..., Any]:
+    """Ensure a function called with the same arguments executes only once
+    per `duration` (default: 100ms).
+    """
+
+    def decorate(f: Callable[..., Any]) -> Callable[..., Any]:
+        last: Dict[str, Any] = {"t": None, "input": None, "output": None}
+
+        def wrapped(*args: Any, **kwargs: Any) -> Any:
+            now = time.time()
+            updated_hash = arghash(args, kwargs)
+            if (
+                last["t"] is None
+                or now - last["t"] >= duration
+                or last["input"] != updated_hash
+            ):
+                result = f(*args, **kwargs)
+                last["t"] = time.time()
+                last["input"] = updated_hash
+                last["output"] = result
+                return result
+            return last["output"]
+
+        return wrapped
+
+    return decorate
diff --git a/superset/views/core.py b/superset/views/core.py
index 9f16f85c36267..561aed80beff1 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -17,7 +17,6 @@
 # pylint: disable=comparison-with-callable
 import logging
 import re
-from collections import defaultdict
 from contextlib import closing
 from datetime import datetime
 from typing import Any, cast, Dict, List, Optional, Union
@@ -44,7 +43,6 @@
 from sqlalchemy.orm.session import Session
 from werkzeug.urls import Href
 
-import superset.models.core as models
 from superset import (
     app,
     appbuilder,
@@ -78,7 +76,7 @@
     SupersetTimeoutException,
 )
 from superset.jinja_context import get_template_processor
-from superset.models.core import Database
+from superset.models.core import Database, FavStar, Log
 from superset.models.dashboard import Dashboard
 from superset.models.datasource_access_request import DatasourceAccessRequest
 from superset.models.slice import Slice
@@ -275,7 +273,7 @@ def approve(self) -> FlaskResponse:  # pylint: disable=too-many-locals,no-self-u
         def clean_fulfilled_requests(session: Session) -> None:
             for dar in session.query(DAR).all():
                 datasource = ConnectorRegistry.get_datasource(
-                    dar.datasource_type, dar.datasource_id, session
+                    dar.datasource_type, dar.datasource_id, session,
                 )
                 if not datasource or security_manager.can_access_datasource(datasource):
                     # datasource does not exist anymore
@@ -752,7 +750,7 @@ def filter(  # pylint: disable=no-self-use
         """
         # TODO: Cache endpoint by user, datasource and column
         datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, db.session
+            datasource_type, datasource_id, db.session,
         )
         if not datasource:
            return json_error_response(DATASOURCE_MISSING_ERR)
@@ -893,7 +891,7 @@ def schemas(  # pylint: disable=no-self-use
             "This API endpoint is deprecated and will be removed in version 1.0.0"
         )
         db_id = int(db_id)
-        database = db.session.query(models.Database).get(db_id)
+        database = db.session.query(Database).get(db_id)
         if database:
             schemas = database.get_all_schema_names(
                 cache=database.schema_cache_enabled,
@@ -915,8 +913,8 @@ def tables(  # pylint: disable=too-many-locals,no-self-use
     ) -> FlaskResponse:
         """Endpoint to fetch the list of tables for given database"""
         # Guarantees database filtering by security access
-        query = db.session.query(models.Database)
-        query = DatabaseFilter("id", SQLAInterface(models.Database, db.session)).apply(
+        query = db.session.query(Database)
+        query = DatabaseFilter("id", SQLAInterface(Database, db.session)).apply(
             query, None
         )
         database = query.filter_by(id=db_id).one_or_none()
@@ -1024,7 +1022,7 @@ def copy_dash(  # pylint: disable=no-self-use
         # remove it to avoid confusion.
         data.pop("last_modified_time", None)
 
-        dash = models.Dashboard()
+        dash = Dashboard()
         original_dash = session.query(Dashboard).get(dashboard_id)
 
         dash.owners = [g.user] if g.user else []
@@ -1045,7 +1043,7 @@ def copy_dash(  # pylint: disable=no-self-use
             for value in data["positions"].values():
                 if isinstance(value, dict) and value.get("meta", {}).get("chartId"):
                     old_id = value["meta"]["chartId"]
-                    new_id = old_to_new_slice_ids[old_id]
+                    new_id = old_to_new_slice_ids.get(old_id)
                     value["meta"]["chartId"] = new_id
         else:
             dash.slices = original_dash.slices
@@ -1128,7 +1126,7 @@ def testconn(  # pylint: disable=too-many-return-statements,no-self-use
             # the connection.
             if db_name:
                 existing_database = (
-                    db.session.query(models.Database)
+                    db.session.query(Database)
                     .filter_by(database_name=db_name)
                     .one_or_none()
                 )
@@ -1137,7 +1135,7 @@ def testconn(  # pylint: disable=too-many-return-statements,no-self-use
 
             # This is the database instance that will be tested. Note the extra fields
             # are represented as JSON encoded strings in the model.
-            database = models.Database(
+            database = Database(
                 server_cert=request.json.get("server_cert"),
                 extra=json.dumps(request.json.get("extra", {})),
                 impersonate_user=request.json.get("impersonate_user"),
@@ -1200,16 +1198,16 @@ def recent_activity(  # pylint: disable=no-self-use
             limit = 1000
 
         qry = (
-            db.session.query(models.Log, models.Dashboard, Slice)
-            .outerjoin(models.Dashboard, models.Dashboard.id == models.Log.dashboard_id)
-            .outerjoin(Slice, Slice.id == models.Log.slice_id)
+            db.session.query(Log, Dashboard, Slice)
+            .outerjoin(Dashboard, Dashboard.id == Log.dashboard_id)
+            .outerjoin(Slice, Slice.id == Log.slice_id)
             .filter(
                 and_(
-                    ~models.Log.action.in_(("queries", "shortner", "sql_json")),
-                    models.Log.user_id == user_id,
+                    ~Log.action.in_(("queries", "shortner", "sql_json")),
+                    Log.user_id == user_id,
                 )
             )
-            .order_by(models.Log.dttm.desc())
+            .order_by(Log.dttm.desc())
            .limit(limit)
         )
         payload = []
@@ -1269,16 +1267,16 @@ def fave_dashboards(  # pylint: disable=no-self-use
         self, user_id: int
     ) -> FlaskResponse:
         qry = (
-            db.session.query(Dashboard, models.FavStar.dttm)
+            db.session.query(Dashboard, FavStar.dttm)
             .join(
-                models.FavStar,
+                FavStar,
                 and_(
-                    models.FavStar.user_id == int(user_id),
-                    models.FavStar.class_name == "Dashboard",
-                    Dashboard.id == models.FavStar.obj_id,
+                    FavStar.user_id == int(user_id),
+                    FavStar.class_name == "Dashboard",
+                    Dashboard.id == FavStar.obj_id,
                 ),
             )
-            .order_by(models.FavStar.dttm.desc())
+            .order_by(FavStar.dttm.desc())
         )
         payload = []
         for o in qry.all():
@@ -1334,7 +1332,6 @@ def user_slices(  # pylint: disable=no-self-use
         """List of slices a user owns, created, modified or faved"""
         if not user_id:
             user_id = g.user.id
-        FavStar = models.FavStar
 
         owner_ids_query = (
             db.session.query(Slice.id)
@@ -1345,11 +1342,11 @@ def user_slices(  # pylint: disable=no-self-use
         qry = (
             db.session.query(Slice, FavStar.dttm)
             .join(
-                models.FavStar,
+                FavStar,
                 and_(
-                    models.FavStar.user_id == user_id,
-                    models.FavStar.class_name == "slice",
-                    Slice.id == models.FavStar.obj_id,
+                    FavStar.user_id == user_id,
+                    FavStar.class_name == "slice",
+                    Slice.id == FavStar.obj_id,
                 ),
                 isouter=True,
             )
@@ -1414,16 +1411,16 @@ def fave_slices(  # pylint: disable=no-self-use
         if not user_id:
             user_id = g.user.id
         qry = (
-            db.session.query(Slice, models.FavStar.dttm)
+            db.session.query(Slice, FavStar.dttm)
             .join(
-                models.FavStar,
+                FavStar,
                 and_(
-                    models.FavStar.user_id == user_id,
-                    models.FavStar.class_name == "slice",
-                    Slice.id == models.FavStar.obj_id,
+                    FavStar.user_id == user_id,
+                    FavStar.class_name == "slice",
+                    Slice.id == FavStar.obj_id,
                 ),
             )
-            .order_by(models.FavStar.dttm.desc())
+            .order_by(FavStar.dttm.desc())
         )
         payload = []
         for o in qry.all():
@@ -1479,9 +1476,9 @@ def warm_up_cache(  # pylint: disable=too-many-locals,no-self-use
         elif table_name and db_name:
             table = (
                 session.query(SqlaTable)
-                .join(models.Database)
+                .join(Database)
                 .filter(
-                    models.Database.database_name == db_name
+                    Database.database_name == db_name
                     or SqlaTable.table_name == table_name
                 )
             ).one_or_none()
@@ -1541,7 +1538,6 @@ def favstar(  # pylint: disable=no-self-use
     ) -> FlaskResponse:
         """Toggle favorite stars on Slices and Dashboard"""
         session = db.session()
-        FavStar = models.FavStar
         count = 0
         favs = (
             session.query(FavStar)
@@ -1621,16 +1617,17 @@ def dashboard(  # pylint: disable=too-many-locals
         if not dash:
             abort(404)
 
-        datasources = defaultdict(list)
-        for slc in dash.slices:
-            datasource = slc.datasource
-            if datasource:
-                datasources[datasource].append(slc)
+        data = dash.full_data()
 
         if config["ENABLE_ACCESS_REQUEST"]:
-            for datasource in datasources:
+            for datasource in data["datasources"].values():
+                datasource = ConnectorRegistry.get_datasource(
+                    datasource_type=datasource["type"],
+                    datasource_id=datasource["id"],
+                    session=session,
+                )
                 if datasource and not security_manager.can_access_datasource(
-                    datasource
+                    datasource=datasource
                 ):
                     flash(
                         __(
@@ -1638,15 +1635,7 @@ def dashboard(  # pylint: disable=too-many-locals
                         ),
                         "danger",
                     )
-                    return redirect(
-                        "superset/request_access/?" f"dashboard_id={dash.id}&"
-                    )
-
-        # Filter out unneeded fields from the datasource payload
-        datasources_payload = {
-            datasource.uid: datasource.data_for_slices(slices)
-            for datasource, slices in datasources.items()
-        }
+                    return redirect(f"/superset/request_access/?dashboard_id={dash.id}")
 
         dash_edit_perm = check_ownership(
             dash, raise_if_false=False
@@ -1675,24 +1664,13 @@ def dashboard(**_: Any) -> None:
                 edit_mode=edit_mode,
             )
 
-        dashboard_data = dash.data
         if is_feature_enabled("REMOVE_SLICE_LEVEL_LABEL_COLORS"):
             # dashboard metadata has dashboard-level label_colors,
             # so remove slice-level label_colors from its form_data
-            for slc in dashboard_data.get("slices"):
+            for slc in data["slices"]:
                 form_data = slc.get("form_data")
                 form_data.pop("label_colors", None)
 
-        dashboard_data.update(
-            {
-                "standalone_mode": standalone_mode,
-                "dash_save_perm": dash_save_perm,
-                "dash_edit_perm": dash_edit_perm,
-                "superset_can_explore": superset_can_explore,
-                "superset_can_csv": superset_can_csv,
-                "slice_can_edit": slice_can_edit,
-            }
-        )
         url_params = {
             key: value
             for key, value in request.args.items()
@@ -1701,11 +1679,19 @@ def dashboard(**_: Any) -> None:
 
         bootstrap_data = {
             "user_id": g.user.get_id(),
-            "dashboard_data": dashboard_data,
-            "datasources": datasources_payload,
             "common": common_bootstrap_payload(),
             "editMode": edit_mode,
             "urlParams": url_params,
+            "dashboard_data": {
+                **data["dashboard"],
+                "standalone_mode": standalone_mode,
+                "dash_save_perm": dash_save_perm,
+                "dash_edit_perm": dash_edit_perm,
+                "superset_can_explore": superset_can_explore,
+                "superset_can_csv": superset_can_csv,
+                "slice_can_edit": slice_can_edit,
+            },
+            "datasources": data["datasources"],
         }
 
         if request.args.get("json") == "true":
@@ -1718,7 +1704,7 @@ def dashboard(**_: Any) -> None:
             entry="dashboard",
             standalone_mode=standalone_mode,
             title=dash.dashboard_title,
-            custom_css=dashboard_data.get("css"),
+            custom_css=dash.css,
             bootstrap_data=json.dumps(
                 bootstrap_data, default=utils.pessimistic_json_iso_dttm_ser
             ),
@@ -1819,7 +1805,7 @@ def sqllab_table_viz(self) -> FlaskResponse:  # pylint: disable=no-self-use
             table = SqlaTable(table_name=table_name, owners=[g.user])
             table.database_id = database_id
             table.database = (
-                db.session.query(models.Database).filter_by(id=database_id).one()
+                db.session.query(Database).filter_by(id=database_id).one()
             )
             table.schema = data.get("schema")
             table.template_params = data.get("templateParams")
@@ -1884,7 +1870,7 @@ def extra_table_metadata(  # pylint: disable=no-self-use
     ) -> FlaskResponse:
         parsed_schema = utils.parse_js_uri_path_item(schema, eval_undefined=True)
         table_name = utils.parse_js_uri_path_item(table_name)  # type: ignore
-        mydb = db.session.query(models.Database).filter_by(id=database_id).one()
+        mydb = db.session.query(Database).filter_by(id=database_id).one()
         payload = mydb.db_engine_spec.extra_table_metadata(
             mydb, table_name, parsed_schema
         )
@@ -1903,7 +1889,7 @@ def select_star(
             self.__class__.__name__,
         )
         stats_logger.incr(f"{self.__class__.__name__}.select_star.init")
-        database = db.session.query(models.Database).get(database_id)
+        database = db.session.query(Database).get(database_id)
         if not database:
             stats_logger.incr(
                 f"deprecated.{self.__class__.__name__}.select_star.database_not_found"
@@ -1936,7 +1922,7 @@ def select_star(
     def estimate_query_cost(  # pylint: disable=no-self-use
         self, database_id: int, schema: Optional[str] = None
    ) -> FlaskResponse:
-        mydb = db.session.query(models.Database).get(database_id)
+        mydb = db.session.query(Database).get(database_id)
 
         sql = json.loads(request.form.get("sql", '""'))
         template_params = json.loads(request.form.get("templateParams") or "{}")
@@ -2090,7 +2076,7 @@ def validate_sql_json(  # pylint: disable=too-many-locals,too-many-return-statem
         )
 
         session = db.session()
-        mydb = session.query(models.Database).filter_by(id=database_id).one_or_none()
+        mydb = session.query(Database).filter_by(id=database_id).one_or_none()
         if not mydb:
             return json_error_response(
                 "Database with id {} is missing.".format(database_id), status=400
@@ -2289,7 +2275,7 @@ def sql_json_exec(  # pylint: disable=too-many-statements,too-many-locals
         status: str = QueryStatus.PENDING if async_flag else QueryStatus.RUNNING
 
         session = db.session()
-        mydb = session.query(models.Database).get(database_id)
+        mydb = session.query(Database).get(database_id)
         if not mydb:
             return json_error_response("Database with id %i is missing.", database_id)
 
@@ -2446,7 +2432,7 @@ def fetch_datasource_metadata(self) -> FlaskResponse:  # pylint: disable=no-self
         datasource_id, datasource_type = request.args["datasourceKey"].split("__")
         datasource = ConnectorRegistry.get_datasource(
-            datasource_type, datasource_id, db.session
+            datasource_type, datasource_id, db.session,
         )
         # Check if datasource exists
         if not datasource:
@@ -2639,7 +2625,7 @@ def _get_sqllab_tabs(user_id: int) -> Dict[str, Any]:
         database.id: {
             k: v for k, v in database.to_json().items() if k in DATABASE_KEYS
         }
-        for database in db.session.query(models.Database).all()
+        for database in db.session.query(Database).all()
     }
     # return all user queries associated with existing SQL editors
     user_queries = (
@@ -2695,7 +2681,7 @@ def schemas_access_for_csv_upload(self) -> FlaskResponse:
             return json_error_response("No database is allowed for your csv upload")
 
         db_id = int(request.args["db_id"])
-        database = db.session.query(models.Database).filter_by(id=db_id).one()
+        database = db.session.query(Database).filter_by(id=db_id).one()
         try:
             schemas_allowed = database.get_schema_access_for_csv_upload()
             if security_manager.can_access_database(database):
diff --git a/superset/views/dashboard/views.py b/superset/views/dashboard/views.py
index a446bf37b9377..b92c4ce002b8c 100644
--- a/superset/views/dashboard/views.py
+++ b/superset/views/dashboard/views.py
@@ -24,9 +24,9 @@
 from flask_appbuilder.security.decorators import has_access
 from flask_babel import gettext as __, lazy_gettext as _
 
-import superset.models.core as models
 from superset import app, db, event_logger
 from superset.constants import RouteMethod
+from superset.models.dashboard import Dashboard as DashboardModel
 from superset.typing import FlaskResponse
 from superset.utils import core as utils
 from superset.views.base import (
@@ -43,7 +43,7 @@ class DashboardModelView(
     DashboardMixin, SupersetModelView, DeleteMixin
 ):  # pylint: disable=too-many-ancestors
     route_base = "/dashboard"
-    datamodel = SQLAInterface(models.Dashboard)
+    datamodel = SQLAInterface(DashboardModel)
     # TODO disable api_read and api_delete (used by cypress)
     # once we move to ChartRestModelApi
     include_route_methods = RouteMethod.CRUD_SET | {
@@ -76,7 +76,7 @@ def download_dashboards(self) -> FlaskResponse:
         if request.args.get("action") == "go":
             ids = request.args.getlist("id")
             return Response(
-                models.Dashboard.export_dashboards(ids),
+                DashboardModel.export_dashboards(ids),
                 headers=generate_download_headers("json"),
                 mimetype="application/text",
             )
@@ -110,7 +110,7 @@ class Dashboard(BaseSupersetView):
     @expose("/new/")
     def new(self) -> FlaskResponse:  # pylint: disable=no-self-use
         """Creates a new, blank dashboard and redirects to it in edit mode"""
-        new_dashboard = models.Dashboard(
+        new_dashboard = DashboardModel(
             dashboard_title="[ untitled dashboard ]", owners=[g.user]
         )
         db.session.add(new_dashboard)
diff --git a/tests/superset_test_config.py b/tests/superset_test_config.py
index 8d03115bde2db..9e3ff65f31860 100644
--- a/tests/superset_test_config.py
+++ b/tests/superset_test_config.py
@@ -50,6 +50,7 @@
 SQL_MAX_ROW = 666
 SQLLAB_CTAS_NO_LIMIT = True  # SQL_MAX_ROW will not take affect for the CTA queries
 FEATURE_FLAGS = {
+    **FEATURE_FLAGS,
     "foo": "bar",
     "KV_STORE": True,
     "SHARE_QUERIES_VIA_KV_STORE": True,

From 7a1ac1840138b53542c58ebadf6d9139aafc6a6d Mon Sep 17 00:00:00 2001
From: Jesse Yang
Date: Tue, 13 Oct 2020 13:23:59 -0700
Subject: [PATCH 2/4] Make it clear returned values are metadata

---
 superset/models/dashboard.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py
index 003fc2fa282c2..c21fb3a7a4449 100644
--- a/superset/models/dashboard.py
+++ b/superset/models/dashboard.py
@@ -264,9 +264,9 @@ def full_data(self) -> Dict[str, Any]:
         return {
             # dashboard metadata
             "dashboard": self.data,
-            # slices data
+            # slices metadata
             "slices": [slc.data for slc in slices],
-            # datasource data
+            # datasource metadata
             "datasources": {
                 # Filter out unneeded fields from the datasource payload
                 datasource.uid: datasource.data_for_slices(slices)

From 7939b0a73c88cade300b288cb8a1106cca85aacc Mon Sep 17 00:00:00 2001
From: Jesse Yang
Date: Tue, 13 Oct 2020 16:23:02 -0700
Subject: [PATCH 3/4] Add basic test case for debounce

---
 tests/{util => utils}/__init__.py           |  0
 tests/utils/decorators_tests.py             | 43 +++++++++++++++++++++
 tests/{util => utils}/machine_auth_tests.py |  0
 3 files changed, 43 insertions(+)
 rename tests/{util => utils}/__init__.py (100%)
 create mode 100644 tests/utils/decorators_tests.py
 rename tests/{util => utils}/machine_auth_tests.py (100%)

diff --git a/tests/util/__init__.py b/tests/utils/__init__.py
similarity index 100%
rename from tests/util/__init__.py
rename to tests/utils/__init__.py
diff --git a/tests/utils/decorators_tests.py b/tests/utils/decorators_tests.py
new file mode 100644
index 0000000000000..da76f19386a3e
--- /dev/null
+++ b/tests/utils/decorators_tests.py
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from unittest.mock import Mock
+
+from superset.utils import decorators
+from tests.base_tests import SupersetTestCase
+
+
+class UtilsDecoratorsTests(SupersetTestCase):
+    def test_debounce(self):
+        mock = Mock()
+
+        @decorators.debounce()
+        def myfunc(arg1: int, arg2: int, kwarg1: str = "abc", kwarg2: int = 2):
+            mock(arg1, kwarg1)
+            return arg1 + arg2 + kwarg2
+
+        # should be called only once when arguments don't change
+        myfunc(1, 1)
+        myfunc(1, 1)
+        result = myfunc(1, 1)
+        mock.assert_called_once_with(1, "abc")
+        self.assertEqual(result, 4)
+
+        # kwarg order shouldn't matter
+        myfunc(1, 0, kwarg2=2, kwarg1="haha")
+        result = myfunc(1, 0, kwarg1="haha", kwarg2=2)
+        mock.assert_has_calls([call(1, "abc"), call(1, "haha")])
+        self.assertEqual(result, 3)
diff --git a/tests/util/machine_auth_tests.py b/tests/utils/machine_auth_tests.py
similarity index 100%
rename from tests/util/machine_auth_tests.py
rename to tests/utils/machine_auth_tests.py

From 985fdb050f2473efbdfa75e3071fafd8d52713c2 Mon Sep 17 00:00:00 2001
From: Jesse Yang
Date: Tue, 13 Oct 2020 16:29:23 -0700
Subject: [PATCH 4/4] fix linting

---
 tests/utils.py                  | 29 -----------------------------
 tests/utils/__init__.py         | 13 +++++++++++++
 tests/utils/decorators_tests.py |  2 +-
 3 files changed, 14 insertions(+), 30 deletions(-)
 delete mode 100644 tests/utils.py

diff --git a/tests/utils.py b/tests/utils.py
deleted file mode 100644
index 832ddd5d886dd..0000000000000
--- a/tests/utils.py
+++ /dev/null
@@ -1,29 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-import json
-from os import path
-
-FIXTURES_DIR = "tests/fixtures"
-
-
-def read_fixture(fixture_file_name):
-    with open(path.join(FIXTURES_DIR, fixture_file_name), "rb") as fixture_file:
-        return fixture_file.read()
-
-
-def load_fixture(fixture_file_name):
-    return json.loads(read_fixture(fixture_file_name))
diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py
index 13a83393a9124..832ddd5d886dd 100644
--- a/tests/utils/__init__.py
+++ b/tests/utils/__init__.py
@@ -14,3 +14,16 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+import json
+from os import path
+
+FIXTURES_DIR = "tests/fixtures"
+
+
+def read_fixture(fixture_file_name):
+    with open(path.join(FIXTURES_DIR, fixture_file_name), "rb") as fixture_file:
+        return fixture_file.read()
+
+
+def load_fixture(fixture_file_name):
+    return json.loads(read_fixture(fixture_file_name))
diff --git a/tests/utils/decorators_tests.py b/tests/utils/decorators_tests.py
index da76f19386a3e..84812546926f4 100644
--- a/tests/utils/decorators_tests.py
+++ b/tests/utils/decorators_tests.py
@@ -14,7 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from unittest.mock import Mock
+from unittest.mock import call, Mock
 
 from superset.utils import decorators
 from tests.base_tests import SupersetTestCase
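
Usage note: a minimal sketch of the `debounce` semantics added in
superset/utils/decorators.py, assuming the series above is applied so that
`superset.utils.decorators.debounce` is importable; `refresh` and its argument
are illustrative names, not part of the patch.

    import time

    from superset.utils.decorators import debounce

    @debounce(0.1)
    def refresh(dashboard_id: int) -> str:
        # the body runs at most once per 100ms for an identical argument hash
        print(f"refreshing dashboard {dashboard_id}")
        return f"refreshed-{dashboard_id}"

    refresh(1)       # executes
    refresh(1)       # debounced: same arghash within 100ms, returns the cached value
    refresh(2)       # executes: a different arghash bypasses the single cached entry
    time.sleep(0.2)
    refresh(1)       # executes again: more than `duration` has elapsed

Because only the most recent call is remembered, alternating argument sets
always re-execute; the decorator is meant to absorb bursts of identical
cache-invalidation calls (e.g. Dashboard.clear_cache_for_slice) fired by
SQLAlchemy events within a single request.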