Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PsqlDosBackend: Fix changes not persisted after iterall and iterdict #6134

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 12 additions & 13 deletions aiida/storage/psql_dos/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,24 +172,21 @@ def _clear(self) -> None:

with self.migrator_context(self._profile) as migrator:

# First clear the contents of the database
with self.transaction() as session:

# Close the session otherwise the ``delete_tables`` call will hang as there will be an open connection
# to the PostgreSQL server and it will block the deletion and the command will hang.
self.get_session().close()
exclude_tables = [migrator.alembic_version_tbl_name, 'db_dbsetting']
migrator.delete_all_tables(exclude_tables=exclude_tables)
# Close the session otherwise the ``delete_tables`` call will hang as there will be an open connection
# to the PostgreSQL server and it will block the deletion and the command will hang.
self.get_session().close()
exclude_tables = [migrator.alembic_version_tbl_name, 'db_dbsetting']
migrator.delete_all_tables(exclude_tables=exclude_tables)

# Clear out all references to database model instances which are now invalid.
session.expunge_all()
# Clear out all references to database model instances which are now invalid.
self.get_session().expunge_all()

# Now reset and reinitialise the repository
migrator.reset_repository()
migrator.initialise_repository()
repository_uuid = migrator.get_repository_uuid()

with self.transaction():
with self.transaction() as session:
session.execute(
DbSetting.__table__.update().where(DbSetting.key == REPOSITORY_UUID_KEY
).values(val=repository_uuid)
Expand Down Expand Up @@ -243,13 +240,15 @@ def transaction(self) -> Iterator[Session]:
"""
session = self.get_session()
if session.in_transaction():
with session.begin_nested():
with session.begin_nested() as savepoint:
yield session
savepoint.commit()
session.commit()
else:
with session.begin():
with session.begin_nested():
with session.begin_nested() as savepoint:
yield session
savepoint.commit()

@property
def in_transaction(self) -> bool:
Expand Down
10 changes: 4 additions & 6 deletions aiida/storage/psql_dos/orm/querybuilder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,8 @@ def iterall(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[Li
# on the session when a yielded row is mutated. This would reset the cursor invalidating it and causing an
# exception to be raised in the next batch of rows in the iteration.
# See https://github.com/python/mypy/issues/10109 for the reason of the type warning.
in_nested_transaction = session.in_nested_transaction()

with nullcontext() if in_nested_transaction else session.begin_nested(): # type: ignore[attr-defined]
with nullcontext() if session.in_nested_transaction() else self._backend.transaction(
): # type: ignore[attr-defined]
for resultrow in session.execute(stmt):
yield [self.to_backend(rowitem) for rowitem in resultrow]

Expand All @@ -188,9 +187,8 @@ def iterdict(self, data: QueryDictType, batch_size: Optional[int]) -> Iterable[D
# on the session when a yielded row is mutated. This would reset the cursor invalidating it and causing an
# exception to be raised in the next batch of rows in the iteration.
# See https://github.com/python/mypy/issues/10109 for the reason of the type warning.
in_nested_transaction = session.in_nested_transaction()

with nullcontext() if in_nested_transaction else session.begin_nested(): # type: ignore[attr-defined]
with nullcontext() if session.in_nested_transaction() else self._backend.transaction(
): # type: ignore[attr-defined]
for row in self.get_session().execute(stmt):
# build the yield result
yield_result: Dict[str, Dict[str, Any]] = {}
Expand Down
13 changes: 12 additions & 1 deletion aiida/storage/sqlite_zip/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,19 @@ def users(self):
def _clear(self) -> None:
raise ReadOnlyError()

@contextmanager
def transaction(self):
raise ReadOnlyError()
session = self.get_session()
if session.in_transaction():
with session.begin_nested() as savepoint:
yield session
savepoint.commit()
session.commit()
else:
with session.begin():
with session.begin_nested() as savepoint:
yield session
savepoint.commit()

@property
def in_transaction(self) -> bool:
Expand Down
32 changes: 32 additions & 0 deletions tests/orm/test_querybuilder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1567,6 +1567,38 @@ def test_iterall_with_store_group(self):
for pk, pk_clone in zip(pks, [e[1] for e in sorted(pks_clone)]):
assert orm.load_node(pk) == orm.load_node(pk_clone)

@pytest.mark.usefixtures('aiida_profile_clean')
def test_iterall_persistence(self, manager):
"""Test that mutations made during ``QueryBuilder.iterall`` context are automatically committed and persisted.

This is a regression test for https://github.com/aiidateam/aiida-core/issues/6133 .
"""
count = 10

# Create number of nodes with specific extra
for _ in range(count):
node = orm.Data().store()
node.base.extras.set('testing', True)

query = orm.QueryBuilder().append(orm.Data, filters={'extras': {'has_key': 'testing'}})
assert query.count() == count

# Unload and reload the storage, which will reset the session and check that the nodes with extras still exist
manager.reset_profile_storage()
manager.get_profile_storage()
assert query.count() == count

# Delete the extras and check that the query now matches 0
for [node] in orm.QueryBuilder().append(orm.Data).iterall(batch_size=2):
node.base.extras.delete('testing')

assert query.count() == 0

# Finally, reset the storage again and verify the changes have been persisted
manager.reset_profile_storage()
manager.get_profile_storage()
assert query.count() == 0


class TestManager:

Expand Down