Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add column property in the version table #301

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions sqlalchemy_continuum/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
from functools import wraps

import sqlalchemy as sa
from sqlalchemy_continuum.expression_reflector import VersionExpressionReflector
from sqlalchemy_continuum.utils import is_table_column
from sqlalchemy_utils.functions import get_declarative_base

from .dialects.postgresql import create_versioning_trigger_listeners
Expand Down Expand Up @@ -189,6 +191,7 @@ def configure_versioned_classes(self):
self.build_relationships(pending_classes_copies)
self.enable_active_history(pending_classes_copies)
self.create_column_aliases(pending_classes_copies)
self.create_column_properties(pending_classes_copies)

def enable_active_history(self, version_classes):
"""
Expand Down Expand Up @@ -220,3 +223,23 @@ def create_column_aliases(self, version_classes):
continue

version_class_mapper.add_property(key, sa.orm.column_property(version_class_column))

def create_column_properties(self, version_classes):
"""
Create equivalent column_property() on the version class (as it is on the parent model)

This does not handle the simple column aliases - just expressions
"""
for cls in version_classes:
model_mapper = sa.inspect(cls)
version_class = self.manager.version_class_map.get(cls)
if not version_class:
continue

version_class_mapper = sa.inspect(version_class)
reflector = VersionExpressionReflector()
for key, column in model_mapper.columns.items():
if is_table_column(column): # We ignore simple table columns
continue
version_column = reflector(column)
version_class_mapper.add_property(key, sa.orm.column_property(version_column))
39 changes: 27 additions & 12 deletions sqlalchemy_continuum/expression_reflector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,7 @@


class VersionExpressionReflector(sa.sql.visitors.ReplacingCloningVisitor):
def __init__(self, parent, relationship):
self.parent = parent
self.relationship = relationship

"""Take an expression and convert the columns to the version_table's columns"""
def replace(self, column):
if not isinstance(column, sa.Column):
return
Expand All @@ -18,16 +15,34 @@ def replace(self, column):
reflected_column = column
else:
reflected_column = table.c[column.name]
if (
column in self.relationship.local_columns and
table == self.parent.__table__
):
reflected_column = bindparam(
column.key,
getattr(self.parent, column.key)
)

return reflected_column

def __call__(self, expr):
return self.traverse(expr)


class RelationshipPrimaryJoinReflector(VersionExpressionReflector):
"""
Takes a relationship and modifies it to handle the primaryjoin of the relationship
"""
def __init__(self, parent, relationship):
self.parent = parent
self.relationship = relationship

def replace(self, column):
reflected_column = super().replace(column)
if reflected_column is None:
return

if (
column in self.relationship.local_columns and
reflected_column.table == self.parent.__table__
):
# Keep the columns from the self.parent.__table__ as is
reflected_column = bindparam(
column.key,
getattr(self.parent, column.key)
)

return reflected_column
12 changes: 6 additions & 6 deletions sqlalchemy_continuum/relationship_builder.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlalchemy as sa

from .exc import ClassNotVersioned
from .expression_reflector import VersionExpressionReflector
from .expression_reflector import RelationshipPrimaryJoinReflector
from .operation import Operation
from .table_builder import TableBuilder
from .utils import adapt_columns, version_class, option
Expand Down Expand Up @@ -46,7 +46,7 @@ def one_to_many_subquery(self, obj):

def many_to_one_subquery(self, obj):
tx_column = option(obj, 'transaction_column_name')
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)

return getattr(self.remote_cls, tx_column) == (
sa.select(
Expand Down Expand Up @@ -93,7 +93,7 @@ def criteria(self, obj):
elif direction.name == 'MANYTOONE':
return self.many_to_one_criteria(obj)
else:
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)
return reflector(self.property.primaryjoin)

def many_to_many_criteria(self, obj):
Expand Down Expand Up @@ -171,7 +171,7 @@ def many_to_one_criteria(self, obj):
AND operation_type != 2

"""
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)
return sa.and_(
reflector(self.property.primaryjoin),
self.many_to_one_subquery(obj),
Expand Down Expand Up @@ -209,7 +209,7 @@ def one_to_many_criteria(self, obj):
)

"""
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)
return sa.and_(
reflector(self.property.primaryjoin),
self.one_to_many_subquery(obj),
Expand Down Expand Up @@ -263,7 +263,7 @@ def association_subquery(self, obj):
tx_column = option(obj, 'transaction_column_name')
join_column = self.property.primaryjoin.right.name
object_join_column = self.property.primaryjoin.left.name
reflector = VersionExpressionReflector(obj, self.property)
reflector = RelationshipPrimaryJoinReflector(obj, self.property)

association_table_alias = self.association_version_table.alias()
association_cols = [
Expand Down
7 changes: 4 additions & 3 deletions sqlalchemy_continuum/version.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import sqlalchemy as sa

from .reverter import Reverter
from .utils import get_versioning_manager, is_internal_column, parent_class
from .utils import get_versioning_manager, is_internal_column, is_table_column, parent_class


class VersionClassBase(object):
Expand Down Expand Up @@ -52,8 +52,9 @@ def changeset(self):
previous_version = self.previous
data = {}

for key in sa.inspect(self.__class__).columns.keys():
if is_internal_column(self, key):
for key, column in sa.inspect(self.__class__).columns.items():
if is_internal_column(self, key) or not is_table_column(column):
# Ignore internal columns and column_property() which are expressions
continue
if not previous_version:
old = None
Expand Down
13 changes: 13 additions & 0 deletions tests/builders/test_model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,16 @@ def test_builds_relationship(self):

def test_parent_has_access_to_versioning_manager(self):
assert self.Article.__versioning_manager__


def test_column_properties(self):
article = self.Article()
article.name = u'Name'
article.content = u'Content'
article.description = u'Desc'
self.session.add(article)
self.session.commit()

article_version = article.versions[0]
assert article.fulltext_content == article.name + article.content + article.description
assert article.fulltext_content == article_version.fulltext_content