Commit

fix: set correct schema on config import
betodealmeida committed Nov 1, 2021
1 parent bea8502 commit 13e32dd
Showing 13 changed files with 97 additions and 31 deletions.
8 changes: 8 additions & 0 deletions superset/commands/importers/v1/examples.py
@@ -17,6 +17,7 @@
 from typing import Any, Dict, List, Set, Tuple
 
 from marshmallow import Schema
+from sqlalchemy import inspect
 from sqlalchemy.orm import Session
 from sqlalchemy.orm.exc import MultipleResultsFound
 from sqlalchemy.sql import select
@@ -114,6 +115,13 @@ def _import(  # pylint: disable=arguments-differ,too-many-locals
                 else:
                     config["database_id"] = database_ids[config["database_uuid"]]
 
+                # set schema
+                if config["schema"] is None:
+                    database = get_example_database()
+                    engine = database.get_sqla_engine()
+                    insp = inspect(engine)
+                    config["schema"] = insp.default_schema_name
+
                 dataset = import_dataset(
                     session, config, overwrite=overwrite, force_data=force_data
                 )
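
The new block falls back to the engine's default schema whenever an imported dataset config carries no schema. For reference, a minimal sketch of the SQLAlchemy inspector call the fix relies on (the connection URL is an illustrative assumption, not part of the commit):

from sqlalchemy import create_engine, inspect

# Illustrative URL; any backend supported by SQLAlchemy behaves the same way.
engine = create_engine("postgresql://user:pass@localhost/examples")

# The inspector reports the schema that unqualified table names resolve to,
# e.g. "public" on PostgreSQL or "main" on SQLite.
insp = inspect(engine)
print(insp.default_schema_name)
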
15 changes: 14 additions & 1 deletion superset/datasets/commands/importers/v1/utils.py
@@ -25,6 +25,7 @@
 from flask import current_app, g
 from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text
 from sqlalchemy.orm import Session
+from sqlalchemy.orm.exc import MultipleResultsFound
 from sqlalchemy.sql.visitors import VisitableType
 
 from superset.connectors.sqla.models import SqlaTable
@@ -110,7 +111,19 @@ def import_dataset(
     data_uri = config.get("data")
 
     # import recursively to include columns and metrics
-    dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
+    try:
+        dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync)
+    except MultipleResultsFound:
+        # Finding multiple results when importing a dataset only happens because
+        # initially datasets were imported without schemas (eg, `examples.NULL.users`),
+        # and later they were fixed to have the default schema (eg,
+        # `examples.public.users`). If a user created `examples.public.users` during
+        # that time the second import will fail because the UUID match will try to
+        # update `examples.NULL.users` to `examples.public.users`, resulting in a
+        # conflict.
+        #
+        # When that happens, we return the original dataset, unmodified.
+        dataset = session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()
 
     if dataset.id is None:
         session.flush()
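
The `except` branch above exists because `Query.one()` raises `MultipleResultsFound` as soon as more than one row matches. A minimal sketch of the same guard, assuming a session and the `SqlaTable` model imported above (the helper and the `config` keys are hypothetical, for illustration only):

from sqlalchemy.orm.exc import MultipleResultsFound

def get_dataset(session, config):
    # A schema-less lookup can match both `examples.NULL.users` and
    # `examples.public.users`, making .one() raise MultipleResultsFound.
    try:
        return session.query(SqlaTable).filter_by(table_name=config["table_name"]).one()
    except MultipleResultsFound:
        # Fall back to the unique UUID and keep the existing row unmodified.
        return session.query(SqlaTable).filter_by(uuid=config["uuid"]).one()
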
8 changes: 6 additions & 2 deletions superset/examples/bart_lines.py
@@ -18,7 +18,7 @@
 
 import pandas as pd
 import polyline
-from sqlalchemy import String, Text
+from sqlalchemy import inspect, String, Text
 
 from superset import db
 from superset.utils.core import get_example_database
@@ -29,6 +29,8 @@
 
 def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
     tbl_name = "bart_lines"
     database = get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -40,7 +42,8 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
 
         df.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={
@@ -59,6 +62,7 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
     tbl = table(table_name=tbl_name)
     tbl.description = "BART lines"
     tbl.database = database
+    tbl.schema = schema
     tbl.filter_select_enabled = True
     db.session.merge(tbl)
     db.session.commit()
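
Each example loader in this commit repeats the pattern shown above for bart_lines.py: resolve the default schema once, pass it to pandas when writing, and record it on the Superset table object. A self-contained sketch of the write path (in-memory SQLite and sample data assumed for illustration):

import pandas as pd
from sqlalchemy import create_engine, inspect

engine = create_engine("sqlite://")  # illustrative in-memory database
schema = inspect(engine).default_schema_name  # "main" on SQLite

df = pd.DataFrame({"name": ["Red", "Yellow"], "path_json": ["[]", "[]"]})

# Passing schema= makes the written table land in, and be recorded under,
# an explicit schema instead of an implicit (NULL) one.
df.to_sql("bart_lines", engine, schema=schema, if_exists="replace", index=False)
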
25 changes: 15 additions & 10 deletions superset/examples/birth_names.py
@@ -20,12 +20,11 @@
 
 import pandas as pd
 from flask_appbuilder.security.sqla.models import User
-from sqlalchemy import DateTime, String
+from sqlalchemy import DateTime, inspect, String
 from sqlalchemy.sql import column
 
 from superset import app, db, security_manager
-from superset.connectors.base.models import BaseDatasource
-from superset.connectors.sqla.models import SqlMetric, TableColumn
+from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
 from superset.exceptions import NoDataException
 from superset.models.core import Database
 from superset.models.dashboard import Dashboard
@@ -75,9 +74,13 @@ def load_data(tbl_name: str, database: Database, sample: bool = False) -> None:
     pdf.ds = pd.to_datetime(pdf.ds, unit="ms")
     pdf = pdf.head(100) if sample else pdf
 
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
+
     pdf.to_sql(
         tbl_name,
         database.get_sqla_engine(),
+        schema=schema,
         if_exists="replace",
         chunksize=500,
         dtype={
@@ -121,14 +124,18 @@ def load_birth_names(
         create_dashboard(slices)
 
 
-def _set_table_metadata(datasource: "BaseDatasource", database: "Database") -> None:
-    datasource.main_dttm_col = "ds"  # type: ignore
+def _set_table_metadata(datasource: SqlaTable, database: "Database") -> None:
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
+
+    datasource.main_dttm_col = "ds"
     datasource.database = database
+    datasource.schema = schema
     datasource.filter_select_enabled = True
     datasource.fetch_metadata()
 
 
-def _add_table_metrics(datasource: "BaseDatasource") -> None:
+def _add_table_metrics(datasource: SqlaTable) -> None:
     if not any(col.column_name == "num_california" for col in datasource.columns):
         col_state = str(column("state").compile(db.engine))
         col_num = str(column("num").compile(db.engine))
@@ -147,13 +154,11 @@ def _add_table_metrics(datasource: "BaseDatasource") -> None:
 
     for col in datasource.columns:
         if col.column_name == "ds":
-            col.is_dttm = True  # type: ignore
+            col.is_dttm = True
             break
 
 
-def create_slices(
-    tbl: BaseDatasource, admin_owner: bool
-) -> Tuple[List[Slice], List[Slice]]:
+def create_slices(tbl: SqlaTable, admin_owner: bool) -> Tuple[List[Slice], List[Slice]]:
     metrics = [
         {
             "expressionType": "SIMPLE",
8 changes: 6 additions & 2 deletions superset/examples/country_map.py
@@ -17,7 +17,7 @@
 import datetime
 
 import pandas as pd
-from sqlalchemy import BigInteger, Date, String
+from sqlalchemy import BigInteger, Date, inspect, String
 from sqlalchemy.sql import column
 
 from superset import db
@@ -38,6 +38,8 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
     """Loading data for map with country map"""
     tbl_name = "birth_france_by_region"
     database = utils.get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -48,7 +50,8 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
         data["dttm"] = datetime.datetime.now().date()
         data.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={
@@ -79,6 +82,7 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> None:
     obj = table(table_name=tbl_name)
     obj.main_dttm_col = "dttm"
     obj.database = database
+    obj.schema = schema
     obj.filter_select_enabled = True
     if not any(col.metric_name == "avg__2004" for col in obj.metrics):
         col = str(column("2004").compile(db.engine))
8 changes: 6 additions & 2 deletions superset/examples/energy.py
@@ -18,7 +18,7 @@
 import textwrap
 
 import pandas as pd
-from sqlalchemy import Float, String
+from sqlalchemy import Float, inspect, String
 from sqlalchemy.sql import column
 
 from superset import db
@@ -40,6 +40,8 @@ def load_energy(
     """Loads an energy related dataset to use with sankey and graphs"""
     tbl_name = "energy_usage"
     database = utils.get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -48,7 +50,8 @@ def load_energy(
         pdf = pdf.head(100) if sample else pdf
         pdf.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={"source": String(255), "target": String(255), "value": Float()},
@@ -63,6 +66,7 @@ def load_energy(
     tbl = table(table_name=tbl_name)
     tbl.description = "Energy consumption"
     tbl.database = database
+    tbl.schema = schema
     tbl.filter_select_enabled = True
 
     if not any(col.metric_name == "sum__value" for col in tbl.metrics):
8 changes: 6 additions & 2 deletions superset/examples/flights.py
@@ -15,7 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 import pandas as pd
-from sqlalchemy import DateTime
+from sqlalchemy import DateTime, inspect
 
 from superset import db
 from superset.utils import core as utils
@@ -27,6 +27,8 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
     """Loading random time series data from a zip file in the repo"""
     tbl_name = "flights"
     database = utils.get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -47,7 +49,8 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
         pdf = pdf.join(airports, on="DESTINATION_AIRPORT", rsuffix="_DEST")
         pdf.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={"ds": DateTime},
Expand All @@ -60,6 +63,7 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
tbl = table(table_name=tbl_name)
tbl.description = "Random set of flights in the US"
tbl.database = database
tbl.schema = schema
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
Expand Down
8 changes: 6 additions & 2 deletions superset/examples/long_lat.py
@@ -19,7 +19,7 @@
 
 import geohash
 import pandas as pd
-from sqlalchemy import DateTime, Float, String
+from sqlalchemy import DateTime, Float, inspect, String
 
 from superset import db
 from superset.models.slice import Slice
@@ -38,6 +38,8 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
     """Loading lat/long data from a csv file in the repo"""
     tbl_name = "long_lat"
     database = utils.get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -56,7 +58,8 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
         pdf["delimited"] = pdf["LAT"].map(str).str.cat(pdf["LON"].map(str), sep=",")
         pdf.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={
@@ -88,6 +91,7 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None:
     obj = table(table_name=tbl_name)
     obj.main_dttm_col = "datetime"
     obj.database = database
+    obj.schema = schema
     obj.filter_select_enabled = True
     db.session.merge(obj)
     db.session.commit()
8 changes: 6 additions & 2 deletions superset/examples/multiformat_time_series.py
@@ -17,7 +17,7 @@
 from typing import Dict, Optional, Tuple
 
 import pandas as pd
-from sqlalchemy import BigInteger, Date, DateTime, String
+from sqlalchemy import BigInteger, Date, DateTime, inspect, String
 
 from superset import app, db
 from superset.models.slice import Slice
@@ -38,6 +38,8 @@ def load_multiformat_time_series(  # pylint: disable=too-many-locals
     """Loading time series data from a zip file in the repo"""
     tbl_name = "multiformat_time_series"
     database = get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -55,7 +57,8 @@ def load_multiformat_time_series(  # pylint: disable=too-many-locals
 
         pdf.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={
Expand All @@ -80,6 +83,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
obj = table(table_name=tbl_name)
obj.main_dttm_col = "ds"
obj.database = database
obj.schema = schema
obj.filter_select_enabled = True
dttm_and_expr_dict: Dict[str, Tuple[Optional[str], None]] = {
"ds": (None, None),
Expand Down
8 changes: 6 additions & 2 deletions superset/examples/paris.py
@@ -17,7 +17,7 @@
 import json
 
 import pandas as pd
-from sqlalchemy import String, Text
+from sqlalchemy import inspect, String, Text
 
 from superset import db
 from superset.utils import core as utils
@@ -28,6 +28,8 @@
 
 def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
     tbl_name = "paris_iris_mapping"
     database = utils.get_example_database()
+    engine = database.get_sqla_engine()
+    schema = inspect(engine).default_schema_name
     table_exists = database.has_table_by_name(tbl_name)
 
     if not only_metadata and (not table_exists or force):
@@ -37,7 +39,8 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) -> None:
 
         df.to_sql(
             tbl_name,
-            database.get_sqla_engine(),
+            engine,
+            schema=schema,
             if_exists="replace",
             chunksize=500,
             dtype={
Expand All @@ -56,6 +59,7 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
tbl = table(table_name=tbl_name)
tbl.description = "Map of Paris"
tbl.database = database
tbl.schema = schema
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
Expand Down