From 13e32dd26b5eaaad33e2463e6c7e5b25b0f9885d Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 28 Oct 2021 14:18:27 -0700 Subject: [PATCH] fix: set correct schema on config import --- superset/commands/importers/v1/examples.py | 8 ++++++ .../datasets/commands/importers/v1/utils.py | 15 ++++++++++- superset/examples/bart_lines.py | 8 ++++-- superset/examples/birth_names.py | 25 +++++++++++-------- superset/examples/country_map.py | 8 ++++-- superset/examples/energy.py | 8 ++++-- superset/examples/flights.py | 8 ++++-- superset/examples/long_lat.py | 8 ++++-- superset/examples/multiformat_time_series.py | 8 ++++-- superset/examples/paris.py | 8 ++++-- superset/examples/random_time_series.py | 8 ++++-- superset/examples/sf_population_polygons.py | 8 ++++-- superset/examples/world_bank.py | 8 ++++-- 13 files changed, 97 insertions(+), 31 deletions(-) diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index 21580fb39e5af..05682e67bd63f 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -17,6 +17,7 @@ from typing import Any, Dict, List, Set, Tuple from marshmallow import Schema +from sqlalchemy import inspect from sqlalchemy.orm import Session from sqlalchemy.orm.exc import MultipleResultsFound from sqlalchemy.sql import select @@ -114,6 +115,13 @@ def _import( # pylint: disable=arguments-differ,too-many-locals else: config["database_id"] = database_ids[config["database_uuid"]] + # set schema + if config["schema"] is None: + database = get_example_database() + engine = database.get_sqla_engine() + insp = inspect(engine) + config["schema"] = insp.default_schema_name + dataset = import_dataset( session, config, overwrite=overwrite, force_data=force_data ) diff --git a/superset/datasets/commands/importers/v1/utils.py b/superset/datasets/commands/importers/v1/utils.py index 78cfae51ba6ed..37522da28c2d2 100644 --- a/superset/datasets/commands/importers/v1/utils.py +++ b/superset/datasets/commands/importers/v1/utils.py @@ -25,6 +25,7 @@ from flask import current_app, g from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text from sqlalchemy.orm import Session +from sqlalchemy.orm.exc import MultipleResultsFound from sqlalchemy.sql.visitors import VisitableType from superset.connectors.sqla.models import SqlaTable @@ -110,7 +111,19 @@ def import_dataset( data_uri = config.get("data") # import recursively to include columns and metrics - dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync) + try: + dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync) + except MultipleResultsFound: + # Finding multiple results when importing a dataset only happens because initially + # datasets were imported without schemas (eg, `examples.NULL.users`), and later + # they were fixed to have the default schema (eg, `examples.public.users`). If a + # user created `examples.public.users` during that time the second import will + # fail because the UUID match will try to update `examples.NULL.users` to + # `examples.public.users`, resulting in a conflict. + # + # When that happens, we return the original dataset, unmodified. + dataset = session.query(SqlaTable).filter_by(uuid=config["uuid"]).one() + if dataset.id is None: session.flush() diff --git a/superset/examples/bart_lines.py b/superset/examples/bart_lines.py index 8cdb8a3bdee8b..ccc417725e16c 100644 --- a/superset/examples/bart_lines.py +++ b/superset/examples/bart_lines.py @@ -18,7 +18,7 @@ import pandas as pd import polyline -from sqlalchemy import String, Text +from sqlalchemy import inspect, String, Text from superset import db from superset.utils.core import get_example_database @@ -29,6 +29,8 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "bart_lines" database = get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -40,7 +42,8 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: df.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -59,6 +62,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None: tbl = table(table_name=tbl_name) tbl.description = "BART lines" tbl.database = database + tbl.schema = schema tbl.filter_select_enabled = True db.session.merge(tbl) db.session.commit() diff --git a/superset/examples/birth_names.py b/superset/examples/birth_names.py index 2fc1fae8c037e..f4e4937344eec 100644 --- a/superset/examples/birth_names.py +++ b/superset/examples/birth_names.py @@ -20,12 +20,11 @@ import pandas as pd from flask_appbuilder.security.sqla.models import User -from sqlalchemy import DateTime, String +from sqlalchemy import DateTime, inspect, String from sqlalchemy.sql import column from superset import app, db, security_manager -from superset.connectors.base.models import BaseDatasource -from superset.connectors.sqla.models import SqlMetric, TableColumn +from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.exceptions import NoDataException from superset.models.core import Database from superset.models.dashboard import Dashboard @@ -75,9 +74,13 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None: pdf.ds = pd.to_datetime(pdf.ds, unit="ms") pdf = pdf.head(100) if sample else pdf + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name + pdf.to_sql( tbl_name, database.get_sqla_engine(), + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -121,14 +124,18 @@ def load_birth_names( create_dashboard(slices) -def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None: - datasource.main_dttm_col = "ds" # type: ignore +def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None: + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name + + datasource.main_dttm_col = "ds" datasource.database = database + datasource.schema = schema datasource.filter_select_enabled = True datasource.fetch_metadata() -def _add_table_metrics(datasource: "BaseDatasource") -> None: +def _add_table_metrics(datasource: SqlaTable) -> None: if not any(col.column_name == "num_california" for col in datasource.columns): col_state = str(column("state").compile(db.engine)) col_num = str(column("num").compile(db.engine)) @@ -147,13 +154,11 @@ def _add_table_metrics(datasource: "BaseDatasource") -> None: for col in datasource.columns: if col.column_name == "ds": - col.is_dttm = True # type: ignore + col.is_dttm = True break -def create_slices( - tbl: BaseDatasource, admin_owner: bool -) -> Tuple[List[Slice], List[Slice]]: +def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[Slice]]: metrics = [ { "expressionType": "SIMPLE", diff --git a/superset/examples/country_map.py b/superset/examples/country_map.py index 4ed5235e6d91c..535b7bff37544 100644 --- a/superset/examples/country_map.py +++ b/superset/examples/country_map.py @@ -17,7 +17,7 @@ import datetime import pandas as pd -from sqlalchemy import BigInteger, Date, String +from sqlalchemy import BigInteger, Date, inspect, String from sqlalchemy.sql import column from superset import db @@ -38,6 +38,8 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N """Loading data for map with country map""" tbl_name = "birth_france_by_region" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -48,7 +50,8 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N data["dttm"] = datetime.datetime.now().date() data.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -79,6 +82,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N obj = table(table_name=tbl_name) obj.main_dttm_col = "dttm" obj.database = database + obj.schema = schema obj.filter_select_enabled = True if not any(col.metric_name == "avg__2004" for col in obj.metrics): col = str(column("2004").compile(db.engine)) diff --git a/superset/examples/energy.py b/superset/examples/energy.py index 4ad56b020da0d..26e20d7dc1f8b 100644 --- a/superset/examples/energy.py +++ b/superset/examples/energy.py @@ -18,7 +18,7 @@ import textwrap import pandas as pd -from sqlalchemy import Float, String +from sqlalchemy import Float, inspect, String from sqlalchemy.sql import column from superset import db @@ -40,6 +40,8 @@ def load_energy( """Loads an energy related dataset to use with sankey and graphs""" tbl_name = "energy_usage" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -48,7 +50,8 @@ def load_energy( pdf = pdf.head(100) if sample else pdf pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={"source": String(255), "target": String(255), "value": Float()}, @@ -63,6 +66,7 @@ def load_energy( tbl = table(table_name=tbl_name) tbl.description = "Energy consumption" tbl.database = database + tbl.schema = schema tbl.filter_select_enabled = True if not any(col.metric_name == "sum__value" for col in tbl.metrics): diff --git a/superset/examples/flights.py b/superset/examples/flights.py index cb72940f60526..fe5d0e7aa0733 100644 --- a/superset/examples/flights.py +++ b/superset/examples/flights.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. import pandas as pd -from sqlalchemy import DateTime +from sqlalchemy import DateTime, inspect from superset import db from superset.utils import core as utils @@ -27,6 +27,8 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: """Loading random time series data from a zip file in the repo""" tbl_name = "flights" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -47,7 +49,8 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST") pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={"ds": DateTime}, @@ -60,6 +63,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None: tbl = table(table_name=tbl_name) tbl.description = "Random set of flights in the US" tbl.database = database + tbl.schema = schema tbl.filter_select_enabled = True db.session.merge(tbl) db.session.commit() diff --git a/superset/examples/long_lat.py b/superset/examples/long_lat.py index 7e2f2f9bdc206..3284d66135c9b 100644 --- a/superset/examples/long_lat.py +++ b/superset/examples/long_lat.py @@ -19,7 +19,7 @@ import geohash import pandas as pd -from sqlalchemy import DateTime, Float, String +from sqlalchemy import DateTime, Float, inspect, String from superset import db from superset.models.slice import Slice @@ -38,6 +38,8 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None """Loading lat/long data from a csv file in the repo""" tbl_name = "long_lat" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -56,7 +58,8 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",") pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -88,6 +91,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None obj = table(table_name=tbl_name) obj.main_dttm_col = "datetime" obj.database = database + obj.schema = schema obj.filter_select_enabled = True db.session.merge(obj) db.session.commit() diff --git a/superset/examples/multiformat_time_series.py b/superset/examples/multiformat_time_series.py index e473ec8c3843a..2c2bca81b1846 100644 --- a/superset/examples/multiformat_time_series.py +++ b/superset/examples/multiformat_time_series.py @@ -17,7 +17,7 @@ from typing import Dict, Optional, Tuple import pandas as pd -from sqlalchemy import BigInteger, Date, DateTime, String +from sqlalchemy import BigInteger, Date, DateTime, inspect, String from superset import app, db from superset.models.slice import Slice @@ -38,6 +38,8 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals """Loading time series data from a zip file in the repo""" tbl_name = "multiformat_time_series" database = get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -55,7 +57,8 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -80,6 +83,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals obj = table(table_name=tbl_name) obj.main_dttm_col = "ds" obj.database = database + obj.schema = schema obj.filter_select_enabled = True dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = { "ds": (None, None), diff --git a/superset/examples/paris.py b/superset/examples/paris.py index 2c16bcee485d3..dc51402ed8a63 100644 --- a/superset/examples/paris.py +++ b/superset/examples/paris.py @@ -17,7 +17,7 @@ import json import pandas as pd -from sqlalchemy import String, Text +from sqlalchemy import inspect, String, Text from superset import db from superset.utils import core as utils @@ -28,6 +28,8 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None: tbl_name = "paris_iris_mapping" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -37,7 +39,8 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> df.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -56,6 +59,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> tbl = table(table_name=tbl_name) tbl.description = "Map of Paris" tbl.database = database + tbl.schema = schema tbl.filter_select_enabled = True db.session.merge(tbl) db.session.commit() diff --git a/superset/examples/random_time_series.py b/superset/examples/random_time_series.py index 394e895a886a6..8adba3f00d918 100644 --- a/superset/examples/random_time_series.py +++ b/superset/examples/random_time_series.py @@ -16,7 +16,7 @@ # under the License. import pandas as pd -from sqlalchemy import DateTime, String +from sqlalchemy import DateTime, inspect, String from superset import app, db from superset.models.slice import Slice @@ -36,6 +36,8 @@ def load_random_time_series_data( """Loading random time series data from a zip file in the repo""" tbl_name = "random_time_series" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -49,7 +51,8 @@ def load_random_time_series_data( pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={"ds": DateTime if database.backend != "presto" else String(255)}, @@ -65,6 +68,7 @@ def load_random_time_series_data( obj = table(table_name=tbl_name) obj.main_dttm_col = "ds" obj.database = database + obj.schema = schema obj.filter_select_enabled = True db.session.merge(obj) db.session.commit() diff --git a/superset/examples/sf_population_polygons.py b/superset/examples/sf_population_polygons.py index 426822c72f604..c4e97ae3f5c96 100644 --- a/superset/examples/sf_population_polygons.py +++ b/superset/examples/sf_population_polygons.py @@ -17,7 +17,7 @@ import json import pandas as pd -from sqlalchemy import BigInteger, Float, Text +from sqlalchemy import BigInteger, Float, inspect, Text from superset import db from superset.utils import core as utils @@ -30,6 +30,8 @@ def load_sf_population_polygons( ) -> None: tbl_name = "sf_population_polygons" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -39,7 +41,8 @@ def load_sf_population_polygons( df.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=500, dtype={ @@ -58,6 +61,7 @@ def load_sf_population_polygons( tbl = table(table_name=tbl_name) tbl.description = "Population density of San Francisco" tbl.database = database + tbl.schema = schema tbl.filter_select_enabled = True db.session.merge(tbl) db.session.commit() diff --git a/superset/examples/world_bank.py b/superset/examples/world_bank.py index 83d710a2be716..8e320774d2f9d 100644 --- a/superset/examples/world_bank.py +++ b/superset/examples/world_bank.py @@ -20,7 +20,7 @@ from typing import List import pandas as pd -from sqlalchemy import DateTime, String +from sqlalchemy import DateTime, inspect, String from sqlalchemy.sql import column from superset import app, db @@ -47,6 +47,8 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals """Loads the world bank health dataset, slices and a dashboard""" tbl_name = "wb_health_population" database = utils.get_example_database() + engine = database.get_sqla_engine() + schema = inspect(engine).default_schema_name table_exists = database.has_table_by_name(tbl_name) if not only_metadata and (not table_exists or force): @@ -62,7 +64,8 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals pdf.to_sql( tbl_name, - database.get_sqla_engine(), + engine, + schema=schema, if_exists="replace", chunksize=50, dtype={ @@ -86,6 +89,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals ) tbl.main_dttm_col = "year" tbl.database = database + tbl.schema = schema tbl.filter_select_enabled = True metrics = [