Skip to content

Commit

Permalink
Merge branch 'native-versioning/tx-table-trigger'
Browse files Browse the repository at this point in the history
  • Loading branch information
kvesteri committed Oct 1, 2014
2 parents dc313b0 + 8686b1c commit c624bbd
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 115 deletions.
3 changes: 2 additions & 1 deletion benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ def test_versioning(

engine = create_engine(dns)
# engine.echo = True
connection = engine.connect()

class Article(Model):
__tablename__ = 'article'
Expand All @@ -79,6 +78,8 @@ class Tag(Model):

sa.orm.configure_mappers()

connection = engine.connect()

Model.metadata.create_all(connection)

Session = sessionmaker(bind=connection)
Expand Down
12 changes: 12 additions & 0 deletions sqlalchemy_continuum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,12 @@ def make_versioned(
'before_cursor_execute',
manager.track_association_operations
)
if manager.options['native_versioning']:
sa.event.listen(
sa.pool.Pool,
'connect',
manager.on_connect
)


def remove_versioning(
Expand Down Expand Up @@ -96,3 +102,9 @@ def remove_versioning(
'before_cursor_execute',
manager.track_association_operations
)
if manager.options['native_versioning']:
sa.event.remove(
sa.pool.Pool,
'connect',
manager.on_connect
)
70 changes: 47 additions & 23 deletions sqlalchemy_continuum/dialects/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,30 @@
"""

temporary_transaction_sql = """
CREATE TEMP TABLE {temporary_transaction_table}
CREATE TEMP TABLE IF NOT EXISTS {temporary_transaction_table}
({transaction_table_columns})
ON COMMIT DROP;
ON COMMIT DELETE ROWS;
"""

insert_temporary_transaction_sql = """
INSERT INTO {temporary_transaction_table} ({transaction_table_columns})
VALUES ({transaction_values});
"""

temp_transaction_trigger_sql = """
CREATE TRIGGER transaction_trigger
AFTER INSERT ON {transaction_table}
FOR EACH ROW EXECUTE PROCEDURE transaction_temp_table_generator()
"""

procedure_sql = """
CREATE OR REPLACE FUNCTION {procedure_name}() RETURNS TRIGGER AS $$
DECLARE transaction_id_value INT;
BEGIN
BEGIN
transaction_id_value = (SELECT id FROM temporary_transaction);
EXCEPTION
WHEN others THEN
INSERT INTO transaction (native_tx_id)
VALUES (txid_current()) RETURNING id INTO transaction_id_value;
{temporary_transaction_sql}
{insert_temporary_transaction_sql}
END;
transaction_id_value = (SELECT id FROM temporary_transaction);
IF transaction_id_value IS NULL THEN
RETURN NEW;
END IF;
IF (TG_OP = 'INSERT') THEN
{after_insert}
Expand Down Expand Up @@ -136,10 +136,7 @@ def transaction_table_name(self):

@property
def temporary_transaction_table_name(self):
if self.table.schema:
return '%s.temporary_transaction' % self.table.schema
else:
return 'temporary_transaction'
return 'temporary_transaction'

@property
def version_table_name(self):
Expand Down Expand Up @@ -359,20 +356,30 @@ def __str__(self):
)


class CreateTemporaryTransactionTableSQL(SQLConstruct):
class TransactionSQLConstruct(object):
def __init__(self, **kwargs):
self.__dict__.update(kwargs)


class CreateTemporaryTransactionTableSQL(TransactionSQLConstruct):
table_name = 'temporary_transaction'

def __str__(self):
return temporary_transaction_sql.format(
temporary_transaction_table=self.temporary_transaction_table_name,
temporary_transaction_table=self.table_name,
transaction_table_columns='id BIGINT, PRIMARY KEY(id)'
)


class InsertTemporaryTransactionSQL(SQLConstruct):
class InsertTemporaryTransactionSQL(TransactionSQLConstruct):
table_name = 'temporary_transaction'
transaction_values = 'transaction_id_value'

def __str__(self):
return insert_temporary_transaction_sql.format(
temporary_transaction_table=self.temporary_transaction_table_name,
temporary_transaction_table=self.table_name,
transaction_table_columns='id',
transaction_values='transaction_id_value'
transaction_values=self.transaction_values
)


Expand All @@ -394,10 +401,10 @@ def __str__(self):
after_update=after_update,
after_delete=after_delete,
temporary_transaction_sql=(
CreateTemporaryTransactionTableSQL(**args)
CreateTemporaryTransactionTableSQL()
),
insert_temporary_transaction_sql=(
InsertTemporaryTransactionSQL(**args)
InsertTemporaryTransactionSQL()
),
upsert_insert=InsertUpsertSQL(**args),
upsert_update=UpdateUpsertSQL(**args),
Expand All @@ -406,6 +413,23 @@ def __str__(self):
return sql


class TransactionTriggerSQL(object):
def __init__(self, tx_class):
self.table = tx_class.__table__

@property
def transaction_table_name(self):
if self.table.schema:
return '%s.transaction' % self.table.schema
else:
return 'transaction'

def __str__(self):
return temp_transaction_trigger_sql.format(
transaction_table=self.transaction_table_name
)


def create_versioning_trigger_listeners(manager, cls):
sa.event.listen(
cls.__table__,
Expand Down
8 changes: 8 additions & 0 deletions sqlalchemy_continuum/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,14 @@ def unit_of_work(self, session):
self.units_of_work[conn] = uow
return uow

def on_connect(self, dbapi_conn, connection_record):
from .dialects.postgresql import CreateTemporaryTransactionTableSQL

cursor = dbapi_conn.cursor()
cursor.execute(str(CreateTemporaryTransactionTableSQL()))
dbapi_conn.commit()
cursor.close()

def before_flush(self, session, flush_context, instances):
"""
Before flush listener for SQLAlchemy sessions. If this manager has
Expand Down
50 changes: 50 additions & 0 deletions sqlalchemy_continuum/transaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles

from .dialects.postgresql import (
CreateTemporaryTransactionTableSQL,
InsertTemporaryTransactionSQL,
TransactionTriggerSQL
)
from .exc import ImproperlyConfigured
from .factory import ModelFactory

Expand Down Expand Up @@ -55,6 +60,48 @@ def changed_entities(self):
return entities


procedure_sql = """
CREATE OR REPLACE FUNCTION transaction_temp_table_generator()
RETURNS TRIGGER AS $$
BEGIN
{temporary_transaction_sql}
INSERT INTO temporary_transaction (id) VALUES (NEW.id);
RETURN NEW;
END;
$$
LANGUAGE plpgsql
"""


def create_triggers(cls):
sa.event.listen(
cls.__table__,
'after_create',
sa.schema.DDL(
procedure_sql.format(
temporary_transaction_sql=CreateTemporaryTransactionTableSQL(),
insert_temporary_transaction_sql=(
InsertTemporaryTransactionSQL(
transaction_id_values='NEW.id'
)
),
)
)
)
sa.event.listen(
cls.__table__,
'after_create',
sa.schema.DDL(str(TransactionTriggerSQL(cls)))
)
sa.event.listen(
cls.__table__,
'after_drop',
sa.schema.DDL(
'DROP FUNCTION IF EXISTS transaction_temp_table_generator()'
)
)


class TransactionFactory(ModelFactory):
model_name = 'Transaction'

Expand Down Expand Up @@ -129,4 +176,7 @@ def __repr__(self):
for field, value in field_values.items()
)
)

if manager.options['native_versioning']:
create_triggers(Transaction)
return Transaction
83 changes: 10 additions & 73 deletions sqlalchemy_continuum/unit_of_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,83 +110,20 @@ def create_transaction(self, session):
args = self.transaction_args(session)

Transaction = self.manager.transaction_cls
table = Transaction.__table__
self.current_transaction = Transaction()

if self.manager.options['native_versioning']:
has_transaction_initialized = bool(session.execute(
'''SELECT 1 FROM pg_catalog.pg_class
WHERE relname = 'temporary_transaction'
'''
).scalar())
if has_transaction_initialized:
tx_id = (
session.execute('SELECT id FROM temporary_transaction')
.scalar()
)
set_committed_value(self.current_transaction, 'id', tx_id)
else:
criteria = {'native_tx_id': sa.func.txid_current()}
args.update(criteria)

query = (
table.insert()
.values(**args)
.returning(*map(sa.text, list(args.keys()) + ['id']))
)

values = session.execute(query).fetchone()
for key, value in values.items():
set_committed_value(self.current_transaction, key, value)

session.execute(
'''
CREATE TEMP TABLE temporary_transaction
(id BIGINT, PRIMARY KEY(id))
ON COMMIT DROP
'''
)
session.execute('''
INSERT INTO temporary_transaction (id)
VALUES (:id)
''',
{'id': self.current_transaction.id}
)
self.merge_transaction(session, self.current_transaction)
else:
for key, value in args.items():
setattr(self.current_transaction, key, value)
if not self.version_session:
self.version_session = sa.orm.session.Session(
bind=session.connection()
)
self.version_session.add(self.current_transaction)
self.version_session.flush()
self.version_session.expunge(self.current_transaction)
session.add(self.current_transaction)
for key, value in args.items():
setattr(self.current_transaction, key, value)
if not self.version_session:
self.version_session = sa.orm.session.Session(
bind=session.connection()
)
self.version_session.add(self.current_transaction)
self.version_session.flush()
self.version_session.expunge(self.current_transaction)
session.add(self.current_transaction)
return self.current_transaction

def merge_transaction(self, session, transaction):
Transaction = self.manager.transaction_cls
state = sa.inspect(self.current_transaction)
state.key = (
Transaction, (self.current_transaction.id,)
)

if hasattr(session, 'hash_key'):
session_id = session.hash_key
else:
# We need this hack when user is using for example
# Flask-SQLAlchemy's scoped session
objs = list(session)
if not objs:
raise Exception('Could not get session id.')

session_id = sa.inspect(objs[0]).session_id

state.session_id = session_id
session.merge(self.current_transaction, load=False)

def get_or_create_version_object(self, target):
"""
Return version object for given parent object. If no version object
Expand Down
6 changes: 4 additions & 2 deletions tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,12 +91,14 @@ def setup_method(self, method):

self.engine = create_engine(get_dns_from_driver(self.driver))
# self.engine.echo = True
self.connection = self.engine.connect()

self.create_models()

sa.orm.configure_mappers()

self.connection = self.engine.connect()



if hasattr(self, 'Article'):
try:
self.ArticleVersion = version_class(self.Article)
Expand Down
2 changes: 1 addition & 1 deletion tests/schema/test_update_end_transaction_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def _insert(self, values):
stmt = table.insert().values(values)
self.session.execute(stmt)

def test_something(self):
def test_update_end_transaction_id(self):
table = version_class(self.Article).__table__
self._insert(
{
Expand Down
15 changes: 0 additions & 15 deletions tests/test_raw_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,6 @@ def assert_has_single_transaction(self):
.count() == 1
)

def test_single_statement(self):
self.session.execute(
"INSERT INTO article (name) VALUES ('some article')"
)
self.assert_has_single_transaction()

def test_multiple_statements(self):
self.session.execute(
"INSERT INTO article (name) VALUES ('some article')"
)
self.session.execute(
"INSERT INTO article (name) VALUES ('some article')"
)
self.assert_has_single_transaction()

def test_flush_after_raw_insert(self):
self.session.execute(
"INSERT INTO article (name) VALUES ('some article')"
Expand Down

0 comments on commit c624bbd

Please sign in to comment.