chore(sqlalchemy): Remove erroneous SQLAlchemy ORM session.merge operations #24776

Merged
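The pattern is the same across the example loaders, migrations, and key/value commands touched below: the object is either newly constructed or was loaded through the session's own query, so the trailing `db.session.merge()` call was at best a no-op and at worst operated on a detached copy. A minimal sketch of the before/after pattern, assuming the Flask-SQLAlchemy `db.session` used throughout Superset (the loader name and attributes are illustrative, not a verbatim excerpt):

```python
from superset import db  # Flask-SQLAlchemy scoped session, as used in the files below


def load_example(table_cls, tbl_name, schema, database):
    tbl = db.session.query(table_cls).filter_by(table_name=tbl_name).first()
    if not tbl:
        tbl = table_cls(table_name=tbl_name, schema=schema)
        db.session.add(tbl)  # new object: attach it to the session once
    # Whether tbl was just add()ed or came back from the query above, it is now
    # tracked by the session, so plain attribute assignment is enough.
    tbl.description = "Example data"
    tbl.database = database
    tbl.filter_select_enabled = True
    # Previously the code called db.session.merge(tbl) here. For a brand-new
    # object that attaches a *copy* to the session, leaving the local variable
    # pointing at a transient instance; for a queried object it is a no-op.
    db.session.commit()
    tbl.fetch_metadata()
```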
2 changes: 1 addition & 1 deletion superset/examples/bart_lines.py
@@ -60,9 +60,9 @@ def load_bart_lines(only_metadata: bool = False, force: bool = False) -> None:
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "BART lines"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
2 changes: 1 addition & 1 deletion superset/examples/country_map.py
@@ -80,13 +80,13 @@ def load_country_map_data(only_metadata: bool = False, force: bool = False) -> N
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "dttm"
obj.database = database
obj.filter_select_enabled = True
if not any(col.metric_name == "avg__2004" for col in obj.metrics):
col = str(column("2004").compile(db.engine))
obj.metrics.append(SqlMetric(metric_name="avg__2004", expression=f"AVG({col})"))
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
4 changes: 2 additions & 2 deletions superset/examples/css_templates.py
@@ -27,6 +27,7 @@ def load_css_templates() -> None:
obj = db.session.query(CssTemplate).filter_by(template_name="Flat").first()
if not obj:
obj = CssTemplate(template_name="Flat")
db.session.add(obj)
css = textwrap.dedent(
"""\
.navbar {
@@ -51,12 +52,12 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()

obj = db.session.query(CssTemplate).filter_by(template_name="Courier Black").first()
if not obj:
obj = CssTemplate(template_name="Courier Black")
db.session.add(obj)
css = textwrap.dedent(
"""\
h2 {
@@ -96,5 +97,4 @@ def load_css_templates() -> None:
"""
)
obj.css = css
db.session.merge(obj)
db.session.commit()
2 changes: 1 addition & 1 deletion superset/examples/deck.py
@@ -532,6 +532,7 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements

if not dash:
dash = Dashboard()
db.session.add(dash)
dash.published = True
js = POSITION_JSON
pos = json.loads(js)
@@ -540,5 +541,4 @@ def load_deck_dash() -> None: # pylint: disable=too-many-statements
dash.dashboard_title = title
dash.slug = slug
dash.slices = slices
db.session.merge(dash)
db.session.commit()
2 changes: 1 addition & 1 deletion superset/examples/energy.py
@@ -66,6 +66,7 @@ def load_energy(
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Energy consumption"
tbl.database = database
tbl.filter_select_enabled = True
@@ -76,7 +77,6 @@ def load_energy(
SqlMetric(metric_name="sum__value", expression=f"SUM({col})")
)

db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

2 changes: 1 addition & 1 deletion superset/examples/flights.py
@@ -63,10 +63,10 @@ def load_flights(only_metadata: bool = False, force: bool = False) -> None:
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Random set of flights in the US"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
print("Done loading table!")
2 changes: 1 addition & 1 deletion superset/examples/long_lat.py
@@ -92,10 +92,10 @@ def load_long_lat_data(only_metadata: bool = False, force: bool = False) -> None
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "datetime"
obj.database = database
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
2 changes: 1 addition & 1 deletion superset/examples/misc_dashboard.py
@@ -34,6 +34,7 @@ def load_misc_dashboard() -> None:

if not dash:
dash = Dashboard()
db.session.add(dash)
js = textwrap.dedent(
"""\
{
@@ -215,5 +216,4 @@ def load_misc_dashboard() -> None:
dash.position_json = json.dumps(pos, indent=4)
dash.slug = DASH_SLUG
dash.slices = slices
db.session.merge(dash)
db.session.commit()
2 changes: 1 addition & 1 deletion superset/examples/multiformat_time_series.py
@@ -82,6 +82,7 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
@@ -100,7 +101,6 @@ def load_multiformat_time_series( # pylint: disable=too-many-locals
col.python_date_format = dttm_and_expr[0]
col.database_expression = dttm_and_expr[1]
col.is_dttm = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
2 changes: 1 addition & 1 deletion superset/examples/paris.py
@@ -57,9 +57,9 @@ def load_paris_iris_geojson(only_metadata: bool = False, force: bool = False) ->
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Map of Paris"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
2 changes: 1 addition & 1 deletion superset/examples/random_time_series.py
@@ -67,10 +67,10 @@ def load_random_time_series_data(
obj = db.session.query(table).filter_by(table_name=tbl_name).first()
if not obj:
obj = table(table_name=tbl_name, schema=schema)
db.session.add(obj)
obj.main_dttm_col = "ds"
obj.database = database
obj.filter_select_enabled = True
db.session.merge(obj)
db.session.commit()
obj.fetch_metadata()
tbl = obj
2 changes: 1 addition & 1 deletion superset/examples/sf_population_polygons.py
@@ -59,9 +59,9 @@ def load_sf_population_polygons(
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = "Population density of San Francisco"
tbl.database = database
tbl.filter_select_enabled = True
db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()
3 changes: 1 addition & 2 deletions superset/examples/tabbed_dashboard.py
@@ -33,6 +33,7 @@ def load_tabbed_dashboard(_: bool = False) -> None:

if not dash:
dash = Dashboard()
db.session.add(dash)

js = textwrap.dedent(
"""
@@ -556,6 +557,4 @@ def load_tabbed_dashboard(_: bool = False) -> None:
dash.slices = slices
dash.dashboard_title = "Tabbed Dashboard"
dash.slug = slug

db.session.merge(dash)
db.session.commit()
4 changes: 2 additions & 2 deletions superset/examples/world_bank.py
@@ -87,6 +87,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
tbl = db.session.query(table).filter_by(table_name=tbl_name).first()
if not tbl:
tbl = table(table_name=tbl_name, schema=schema)
db.session.add(tbl)
tbl.description = utils.readfile(
os.path.join(get_examples_folder(), "countries.md")
)
@@ -110,7 +111,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
SqlMetric(metric_name=metric, expression=f"{aggr_func}({col})")
)

db.session.merge(tbl)
db.session.commit()
tbl.fetch_metadata()

@@ -126,6 +126,7 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s

if not dash:
dash = Dashboard()
db.session.add(dash)
dash.published = True
pos = dashboard_positions
slices = update_slice_ids(pos)
@@ -134,7 +135,6 @@ def load_world_bank_health_n_pop( # pylint: disable=too-many-locals, too-many-s
dash.position_json = json.dumps(pos, indent=4)
dash.slug = slug
dash.slices = slices
db.session.merge(dash)
db.session.commit()


1 change: 0 additions & 1 deletion superset/key_value/commands/update.py
@@ -84,7 +84,6 @@ def update(self) -> Optional[Key]:
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.merge(entry)
Member: Do you know if the autoflush(False) when querying for the model influences anything?

Member Author: @michael-s-molina I'm not entirely sure why auto-flushing has been disabled here (and elsewhere in the key/value commands). Maybe @villebro can provide some context as this was added in #19078.

Member Author: @michael-s-molina I removed the unnecessary auto-flush logic in #26009.

db.session.commit()
return Key(id=entry.id, uuid=entry.uuid)
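On the autoflush question above: disabling autoflush only suppresses the implicit flush that would otherwise run before the lookup query executes; it does not change how attribute changes on the fetched entry reach the database at commit, so dropping the merge should be independent of it. A hedged sketch of that interaction, with simplified, hypothetical names (not the actual key/value command code):

```python
from datetime import datetime

from superset import db  # Flask-SQLAlchemy scoped session


def update_entry(model_cls, resource, key, value, expires_on):
    # no_autoflush only stops pending changes from being flushed before this
    # SELECT runs; it has no bearing on what gets written at commit time.
    with db.session.no_autoflush:
        entry = (
            db.session.query(model_cls)
            .filter_by(resource=resource, key=key)
            .first()
        )
    if entry is None:
        return None
    # entry is persistent (owned by db.session), so the unit of work already
    # tracks these assignments; no session.merge() is needed before commit.
    entry.value = value
    entry.expires_on = expires_on
    entry.changed_on = datetime.now()
    db.session.commit()
    return entry
```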

1 change: 0 additions & 1 deletion superset/key_value/commands/upsert.py
@@ -88,7 +88,6 @@ def upsert(self) -> Key:
entry.expires_on = self.expires_on
entry.changed_on = datetime.now()
entry.changed_by_fk = get_user_id()
db.session.merge(entry)
db.session.commit()
return Key(entry.id, entry.uuid)

12 changes: 4 additions & 8 deletions superset/migrations/shared/migrate_viz/base.py
@@ -123,7 +123,7 @@ def _migrate_temporal_filter(self, rv_data: dict[str, Any]) -> None:
]

@classmethod
def upgrade_slice(cls, slc: Slice) -> Slice:
def upgrade_slice(cls, slc: Slice) -> None:
clz = cls(slc.params)
form_data_bak = copy.deepcopy(clz.data)

@@ -141,10 +141,9 @@ def upgrade_slice(cls, slc: Slice) -> Slice:
if "form_data" in (query_context := try_load_json(slc.query_context)):
query_context["form_data"] = clz.data
slc.query_context = json.dumps(query_context)
return slc
Member Author: Previously this was returning the same object as the one being passed in.


@classmethod
def downgrade_slice(cls, slc: Slice) -> Slice:
def downgrade_slice(cls, slc: Slice) -> None:
form_data = try_load_json(slc.params)
if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})):
slc.params = json.dumps(form_data_bak)
@@ -153,7 +152,6 @@ def downgrade_slice(cls, slc: Slice) -> Slice:
if "form_data" in query_context:
query_context["form_data"] = form_data_bak
slc.query_context = json.dumps(query_context)
return slc

@classmethod
def upgrade(cls, session: Session) -> None:
@@ -162,8 +160,7 @@ def upgrade(cls, session: Session) -> None:
slices,
lambda current, total: print(f"Upgraded {current}/{total} charts"),
):
new_viz = cls.upgrade_slice(slc)
session.merge(new_viz)
cls.upgrade_slice(slc)

@classmethod
def downgrade(cls, session: Session) -> None:
@@ -177,5 +174,4 @@ def downgrade(cls, session: Session) -> None:
slices,
lambda current, total: print(f"Downgraded {current}/{total} charts"),
):
new_viz = cls.downgrade_slice(slc)
session.merge(new_viz)
cls.downgrade_slice(slc)
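As the comment above notes, `upgrade_slice()` and `downgrade_slice()` were returning the very instance passed in, and that instance comes straight from the session's query, so merging the return value re-registered an object the session already tracked. A minimal sketch of the loop after the change (illustrative names; the real migration iterates with a paginated helper):

```python
from sqlalchemy.orm import Session


def upgrade_charts(session: Session, slice_cls, upgrade_slice) -> None:
    for slc in session.query(slice_cls):
        # Before: new_viz = upgrade_slice(slc); session.merge(new_viz)
        # new_viz was the same attached instance as slc, so the merge added
        # nothing; mutating slc in place is sufficient.
        upgrade_slice(slc)  # rewrites slc.params / slc.query_context in place
    session.commit()  # shown for completeness; the real migration commits elsewhere
```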
1 change: 0 additions & 1 deletion superset/migrations/shared/security_converge.py
@@ -243,7 +243,6 @@ def migrate_roles(
if new_pvm not in role.permissions:
logger.info(f"Add {new_pvm} to {role}")
role.permissions.append(new_pvm)
session.merge(role)

# Delete old permissions
_delete_old_permissions(session, pvm_map)
@@ -56,7 +56,6 @@ def upgrade():
for slc in session.query(Slice).all():
if slc.datasource:
slc.perm = slc.datasource.perm
session.merge(slc)
session.commit()
db.session.close()

@@ -56,7 +56,6 @@ def upgrade():
slc.datasource_id = slc.druid_datasource_id
if slc.table_id:
slc.datasource_id = slc.table_id
session.merge(slc)
session.commit()
session.close()

@@ -69,7 +68,6 @@ def downgrade():
slc.druid_datasource_id = slc.datasource_id
if slc.datasource_type == "table":
slc.table_id = slc.datasource_id
session.merge(slc)
session.commit()
session.close()
op.drop_column("slices", "datasource_id")
@@ -57,7 +57,6 @@ def upgrade():
try:
d = json.loads(slc.params or "{}")
slc.params = json.dumps(d, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:
@@ -80,7 +80,6 @@ def upgrade():
"/".join(split[:-1]) + "/?form_data=" + parse.quote_plus(json.dumps(d))
)
url.url = newurl
session.merge(url)
session.commit()
print(f"Updating url ({i}/{urls_len})")
session.close()
@@ -58,7 +58,6 @@ def upgrade():
del params["latitude"]
del params["longitude"]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

@@ -69,7 +69,6 @@ def upgrade():
)
params["annotation_layers"] = new_layers
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()

@@ -86,6 +85,5 @@ def downgrade():
if layers:
params["annotation_layers"] = [layer["value"] for layer in layers]
slc.params = json.dumps(params)
session.merge(slc)
session.commit()
session.close()
@@ -62,7 +62,6 @@ def upgrade():
pos["v"] = 1

dashboard.position_json = json.dumps(positions, indent=2)
session.merge(dashboard)
session.commit()

session.close()
@@ -85,6 +84,5 @@ def downgrade():
pos["v"] = 0

dashboard.position_json = json.dumps(positions, indent=2)
session.merge(dashboard)
session.commit()
pass
@@ -59,7 +59,6 @@ def upgrade():
params["metrics"] = [params.get("metric")]
del params["metric"]
slc.params = json.dumps(params, indent=2, sort_keys=True)
session.merge(slc)
session.commit()
print(f"Upgraded ({i}/{slice_len}): {slc.slice_name}")
except Exception as ex:
@@ -647,7 +647,6 @@ def upgrade():

sorted_by_key = collections.OrderedDict(sorted(v2_layout.items()))
dashboard.position_json = json.dumps(sorted_by_key, indent=2)
session.merge(dashboard)
session.commit()
else:
print(f"Skip converted dash_id: {dashboard.id}")
@@ -76,7 +76,6 @@ def upgrade():
dashboard.id, len(original_text), len(text)
)
)
session.merge(dashboard)
session.commit()

