Skip to content

Commit

Permalink
feat: Adds CLI commands to execute viz migrations
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina committed Sep 15, 2023
1 parent e1ddba9 commit 3585b25
Show file tree
Hide file tree
Showing 2 changed files with 95 additions and 24 deletions.
89 changes: 89 additions & 0 deletions superset/cli/viz_migrations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from enum import Enum

import click
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
from flask.cli import with_appcontext


class VizTypes(str, Enum):
treemap = "treemap"
dual_line = "dual_line"
area = "area"
pivot_table = "pivot_table"


@click.group()
def viz_migrations() -> None:
"""
Migrates a viz from one type to another.
"""


@viz_migrations.command()
@with_appcontext
@optgroup.group(
"Grouped options",
cls=RequiredMutuallyExclusiveOptionGroup,
)
@optgroup.option(
"--type",
"-t",
help=f"The viz type to migrate: {', '.join([type for type in VizTypes])}",
)
def upgrade(type: str) -> None:
"""Upgrade a viz to the latest version."""
migrate_viz(VizTypes(type))


@viz_migrations.command()
@with_appcontext
@optgroup.group(
"Grouped options",
cls=RequiredMutuallyExclusiveOptionGroup,
)
@optgroup.option(
"--type",
"-t",
help=f"The viz type to migrate: {', '.join([type for type in VizTypes])}",
)
def downgrade(type: str) -> None:
"""Downgrades a viz to the previous version."""
migrate_viz(VizTypes(type), downgrade=True)


def migrate_viz(type: VizTypes, downgrade: bool = False) -> None:
"""Migrates a viz from one type to another."""
from superset.migrations.shared.migrate_viz.base import MigrateViz
from superset.migrations.shared.migrate_viz.processors import (
MigrateAreaChart,
MigrateDualLine,
MigratePivotTable,
MigrateTreeMap,
)

migrations = {
VizTypes.treemap: MigrateTreeMap,
VizTypes.dual_line: MigrateDualLine,
VizTypes.area: MigrateAreaChart,
VizTypes.pivot_table: MigratePivotTable,
}
if downgrade:
migrations[type].downgrade()
else:
migrations[type].upgrade()
30 changes: 6 additions & 24 deletions superset/migrations/shared/migrate_viz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,26 +20,12 @@
import json
from typing import Any

from alembic import op
from sqlalchemy import and_, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy import and_

from superset import conf, db, is_feature_enabled
from superset.constants import TimeGrain
from superset.migrations.shared.utils import paginated_update, try_load_json

Base = declarative_base()


class Slice(Base): # type: ignore
__tablename__ = "slices"

id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
viz_type = Column(String(250))
params = Column(Text)
query_context = Column(Text)

from superset.models.slice import Slice

FORM_DATA_BAK_FIELD_NAME = "form_data_bak"

Expand Down Expand Up @@ -157,23 +143,19 @@ def downgrade_slice(cls, slc: Slice) -> Slice:

@classmethod
def upgrade(cls) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
slices = db.session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
for slc in paginated_update(
slices,
lambda current, total: print(
f" Updating {current}/{total} charts", end="\r"
),
):
new_viz = cls.upgrade_slice(slc)
session.merge(new_viz)
db.session.merge(new_viz)

@classmethod
def downgrade(cls) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
slices = session.query(Slice).filter(
slices = db.session.query(Slice).filter(
and_(
Slice.viz_type == cls.target_viz_type,
Slice.params.like(f"%{FORM_DATA_BAK_FIELD_NAME}%"),
Expand All @@ -186,4 +168,4 @@ def downgrade(cls) -> None:
),
):
new_viz = cls.downgrade_slice(slc)
session.merge(new_viz)
db.session.merge(new_viz)

0 comments on commit 3585b25

Please sign in to comment.