Skip to content

Commit

Permalink
fix: fix table existence validation function (#11066)
Browse files Browse the repository at this point in the history
* Fix table existance validation function

* Drop left over table name index in mysql db

* Do not modify model

Co-authored-by: bogdan kyryliuk <[email protected]>
  • Loading branch information
bkyryliuk and bogdan-dbx authored Sep 29, 2020
1 parent 0409b12 commit 03eebd3
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 6 deletions.
4 changes: 2 additions & 2 deletions superset/datasets/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,11 +70,11 @@ def validate(self) -> None:
exceptions: List[ValidationError] = list()
database_id = self._properties["database"]
table_name = self._properties["table_name"]
schema = self._properties.get("schema", "")
schema = self._properties.get("schema", None)
owner_ids: Optional[List[int]] = self._properties.get("owners")

# Validate uniqueness
if not DatasetDAO.validate_uniqueness(database_id, table_name):
if not DatasetDAO.validate_uniqueness(database_id, schema, table_name):
exceptions.append(DatasetExistsValidationError(table_name))

# Validate/Populate database
Expand Down
10 changes: 7 additions & 3 deletions superset/datasets/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,9 @@ def get_related_objects(database_id: int) -> Dict[str, Any]:
return dict(charts=charts, dashboards=dashboards)

@staticmethod
def validate_table_exists(database: Database, table_name: str, schema: str) -> bool:
def validate_table_exists(
database: Database, table_name: str, schema: Optional[str]
) -> bool:
try:
database.get_table(table_name, schema=schema)
return True
Expand All @@ -83,9 +85,11 @@ def validate_table_exists(database: Database, table_name: str, schema: str) -> b
return False

@staticmethod
def validate_uniqueness(database_id: int, name: str) -> bool:
def validate_uniqueness(database_id: int, schema: Optional[str], name: str) -> bool:
dataset_query = db.session.query(SqlaTable).filter(
SqlaTable.table_name == name, SqlaTable.database_id == database_id
SqlaTable.table_name == name,
SqlaTable.schema == schema,
SqlaTable.database_id == database_id,
)
return not db.session.query(dataset_query.exists()).scalar()

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Delete table_name unique constraint in mysql
Revision ID: 18532d70ab98
Revises: e5ef6828ac4e
Create Date: 2020-09-25 10:56:13.711182
"""

# revision identifiers, used by Alembic.
revision = "18532d70ab98"
down_revision = "e5ef6828ac4e"

from alembic import op


def upgrade():
try:
# index only exists in mysql db
with op.get_context().autocommit_block():
op.drop_constraint("table_name", "tables", type_="unique")
except Exception as ex:
print(ex)


def downgrade():
pass
33 changes: 32 additions & 1 deletion tests/datasets/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
)
from superset.extensions import db, security_manager
from superset.models.core import Database
from superset.utils.core import get_example_database, get_main_database
from superset.utils.core import backend, get_example_database, get_main_database
from superset.utils.dict_import_export import export_to_dict
from superset.views.base import generate_download_headers
from tests.base_tests import SupersetTestCase
from tests.conftest import CTAS_SCHEMA_NAME


class TestDatasetApi(SupersetTestCase):
Expand Down Expand Up @@ -387,6 +388,36 @@ def test_create_dataset_validate_uniqueness(self):
data, {"message": {"table_name": ["Datasource birth_names already exists"]}}
)

def test_create_dataset_same_name_different_schema(self):
if backend() == "sqlite":
# sqlite doesn't support schemas
return

example_db = get_example_database()
example_db.get_sqla_engine().execute(
f"CREATE TABLE {CTAS_SCHEMA_NAME}.birth_names AS SELECT 2 as two"
)

self.login(username="admin")
table_data = {
"database": example_db.id,
"schema": CTAS_SCHEMA_NAME,
"table_name": "birth_names",
}

uri = "api/v1/dataset/"
rv = self.post_assert_metric(uri, table_data, "post")
self.assertEqual(rv.status_code, 201)

# cleanup
data = json.loads(rv.data.decode("utf-8"))
uri = f'api/v1/dataset/{data.get("id")}'
rv = self.client.delete(uri)
self.assertEqual(rv.status_code, 200)
example_db.get_sqla_engine().execute(
f"DROP TABLE {CTAS_SCHEMA_NAME}.birth_names"
)

def test_create_dataset_validate_database(self):
"""
Dataset API: Test create dataset validate database exists
Expand Down

0 comments on commit 03eebd3

Please sign in to comment.