Skip to content

Commit

Permalink
track cloned connections
Browse files Browse the repository at this point in the history
Hotfix/track cloned connections (#2)

Hotfix/track cloned connections (#3)

change iterations

use only ConnectionFairy as indexes

create new dict entry for clones

create entry using old connection

create entry using session

revert to event listening strategy

remove cloned connections on rollback

treat None case when cleaning connections

Hotfix/track cloned connections (#2)

Hotfix/track cloned connections (#3)

change iterations

use only ConnectionFairy as indexes

create new dict entry for clones

create entry using old connection

create entry using session

revert to event listening strategy

remove cloned connections on rollback

treat None case when cleaning connections
  • Loading branch information
Fernando Cezar authored and Aleksandr Bogdanov committed May 16, 2018
1 parent f98469d commit acef01b
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 2 deletions.
12 changes: 12 additions & 0 deletions sqlalchemy_continuum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def make_versioned(
manager.track_association_operations
)

sa.event.listen(
sa.engine.Engine,
'set_connection_execution_options',
manager.track_cloned_connections
)


def remove_versioning(
mapper=sa.orm.mapper,
Expand All @@ -96,3 +102,9 @@ def remove_versioning(
'before_cursor_execute',
manager.track_association_operations
)

sa.event.remove(
sa.engine.Engine,
'set_connection_execution_options',
manager.track_cloned_connections
)
38 changes: 36 additions & 2 deletions sqlalchemy_continuum/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,15 @@ def wrapper(self, mapper, connection, target):
try:
uow = self.units_of_work[conn]
except KeyError:
uow = self.units_of_work[conn.engine]
try:
uow = self.units_of_work[conn.engine]
except KeyError:
for connection in self.units_of_work.keys():
if connection.connection is conn.connection:
uow = self.unit_of_work(session)
break # The ConnectionFairy is the same, this connection is a clone
else:
raise KeyError
return func(self, uow, target)
return wrapper

Expand Down Expand Up @@ -357,11 +365,20 @@ def clear(self, session):
if session.transaction.nested:
return
conn = self.session_connection_map.pop(session, None)
if conn is None:
return

if conn in self.units_of_work:
uow = self.units_of_work[conn]
uow.reset(session)
del self.units_of_work[conn]

for connection in dict(self.units_of_work).keys():
if conn.connection is connection.connection:
uow = self.units_of_work[connection]
uow.reset(session)
del self.units_of_work[connection]

def append_association_operation(self, conn, table_name, params, op):
"""
Append history association operation to pending_statements list.
Expand All @@ -375,9 +392,26 @@ def append_association_operation(self, conn, table_name, params, op):
try:
uow = self.units_of_work[conn]
except KeyError:
uow = self.units_of_work[conn.engine]
try:
uow = self.units_of_work[conn.engine]
except KeyError:
for connection in self.units_of_work.keys():
if connection.connection is conn.connection:
uow = self.unit_of_work(conn.session)
break # The ConnectionFairy is the same, this connection is a clone
else:
raise KeyError
uow.pending_statements.append(stmt)

def track_cloned_connections(self, c, opt):
"""
Track cloned connections from association tables.
"""
if c not in self.units_of_work.keys():
for connection, uow in dict(self.units_of_work).items():
if connection.connection is c.connection: # ConnectionFairy is the same - this is a clone
self.units_of_work[c] = uow

def track_association_operations(
self, conn, cursor, statement, parameters, context, executemany
):
Expand Down

0 comments on commit acef01b

Please sign in to comment.