Skip to content

Commit

Permalink
Receive session as a param
Browse files Browse the repository at this point in the history
  • Loading branch information
michael-s-molina committed Sep 15, 2023
1 parent e997633 commit 5adc898
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 30 deletions.
6 changes: 4 additions & 2 deletions superset/cli/viz_migrations.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from click_option_group import optgroup, RequiredMutuallyExclusiveOptionGroup
from flask.cli import with_appcontext

from superset import db


class VizTypes(str, Enum):
TREEMAP = "treemap"
Expand Down Expand Up @@ -84,6 +86,6 @@ def migrate_viz(viz_type: VizTypes, is_downgrade: bool = False) -> None:
VizTypes.PIVOT_TABLE: MigratePivotTable,
}
if is_downgrade:
migrations[viz_type].downgrade()
migrations[viz_type].downgrade(db.session)
else:
migrations[viz_type].upgrade()
migrations[viz_type].upgrade(db.session)
46 changes: 35 additions & 11 deletions superset/migrations/shared/migrate_viz/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,26 @@
import json
from typing import Any

from sqlalchemy import and_
from sqlalchemy import and_, Column, Integer, String, Text
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session

from superset import conf, db, is_feature_enabled
from superset import conf, is_feature_enabled
from superset.constants import TimeGrain
from superset.migrations.shared.utils import paginated_update, try_load_json
from superset.models.slice import Slice

Base = declarative_base()


class Slice(Base): # type: ignore
__tablename__ = "slices"

id = Column(Integer, primary_key=True)
slice_name = Column(String(250))
viz_type = Column(String(250))
params = Column(Text)
query_context = Column(Text)


FORM_DATA_BAK_FIELD_NAME = "form_data_bak"

Expand Down Expand Up @@ -142,20 +156,30 @@ def downgrade_slice(cls, slc: Slice) -> Slice:
return slc

@classmethod
def upgrade(cls) -> None:
slices = db.session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
for slc in slices:
def upgrade(cls, session: Session) -> None:
slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type)
for slc in paginated_update(
slices,
lambda current, total: print(
f" Updating {current}/{total} charts", end="\r"
),
):
new_viz = cls.upgrade_slice(slc)
db.session.merge(new_viz)
session.merge(new_viz)

@classmethod
def downgrade(cls) -> None:
slices = db.session.query(Slice).filter(
def downgrade(cls, session: Session) -> None:
slices = session.query(Slice).filter(
and_(
Slice.viz_type == cls.target_viz_type,
Slice.params.like(f"%{FORM_DATA_BAK_FIELD_NAME}%"),
)
)
for slc in slices:
for slc in paginated_update(
slices,
lambda current, total: print(
f" Downgrading {current}/{total} charts", end="\r"
),
):
new_viz = cls.downgrade_slice(slc)
db.session.merge(new_viz)
session.merge(new_viz)
7 changes: 2 additions & 5 deletions superset/migrations/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,17 +97,14 @@ def paginated_update(
query: Query,
print_page_progress: Optional[Union[Callable[[int, int], None], bool]] = None,
batch_size: int = DEFAULT_BATCH_SIZE,
session: Session = None,
) -> Iterator[Any]:
"""
Update models in small batches so we don't have to load everything in memory.
"""

total = query.count()
processed = 0

if not session:
session = inspect(query).session

session: Session = inspect(query).session
result = session.execute(query)

if print_page_progress is None or print_page_progress is True:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect

from superset import db
from superset.migrations.shared.migrate_viz import MigrateTreeMap

# revision identifiers, used by Alembic.
Expand All @@ -32,16 +33,21 @@


def upgrade():
bind = op.get_bind()

# Ensure `slice.params` and `slice.query_context`` in MySQL is MEDIUMTEXT
# before migration, as the migration will save a duplicate form_data backup
# which may significantly increase the size of these fields.
if isinstance(op.get_bind().dialect, MySQLDialect):
if isinstance(bind.dialect, MySQLDialect):
# If the columns are already MEDIUMTEXT, this is a no-op
op.execute("ALTER TABLE slices MODIFY params MEDIUMTEXT")
op.execute("ALTER TABLE slices MODIFY query_context MEDIUMTEXT")

MigrateTreeMap.upgrade()
session = db.Session(bind=bind)
MigrateTreeMap.upgrade(session)


def downgrade():
MigrateTreeMap.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateTreeMap.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Create Date: 2022-06-13 14:17:51.872706
"""
from alembic import op

from superset import db
from superset.migrations.shared.migrate_viz import MigrateAreaChart

# revision identifiers, used by Alembic.
Expand All @@ -29,8 +32,12 @@


def upgrade():
MigrateAreaChart.upgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateAreaChart.upgrade(session)


def downgrade():
MigrateAreaChart.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateAreaChart.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Create Date: 2023-08-06 09:02:10.148992
"""
from alembic import op

from superset import db
from superset.migrations.shared.migrate_viz import MigratePivotTable

# revision identifiers, used by Alembic.
Expand All @@ -29,8 +32,12 @@


def upgrade():
MigratePivotTable.upgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigratePivotTable.upgrade(session)


def downgrade():
MigratePivotTable.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigratePivotTable.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect

from superset import db
from superset.migrations.shared.migrate_viz import MigrateTreeMap

# revision identifiers, used by Alembic.
Expand All @@ -32,16 +33,21 @@


def upgrade():
bind = op.get_bind()

# Ensure `slice.params` and `slice.query_context`` in MySQL is MEDIUMTEXT
# before migration, as the migration will save a duplicate form_data backup
# which may significantly increase the size of these fields.
if isinstance(op.get_bind().dialect, MySQLDialect):
if isinstance(bind.dialect, MySQLDialect):
# If the columns are already MEDIUMTEXT, this is a no-op
op.execute("ALTER TABLE slices MODIFY params MEDIUMTEXT")
op.execute("ALTER TABLE slices MODIFY query_context MEDIUMTEXT")

MigrateTreeMap.upgrade()
session = db.Session(bind=bind)
MigrateTreeMap.upgrade(session)


def downgrade():
MigrateTreeMap.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateTreeMap.downgrade(session)
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
Create Date: 2023-06-08 11:34:36.241939
"""
from alembic import op

from superset import db

# revision identifiers, used by Alembic.
revision = "ae58e1e58e5c"
Expand All @@ -30,8 +33,12 @@


def upgrade():
MigrateDualLine.upgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateDualLine.upgrade(session)


def downgrade():
MigrateDualLine.downgrade()
bind = op.get_bind()
session = db.Session(bind=bind)
MigrateDualLine.downgrade(session)

0 comments on commit 5adc898

Please sign in to comment.