diff --git a/sqlalchemy_continuum/builder.py b/sqlalchemy_continuum/builder.py index 53134029..21e20c0d 100644 --- a/sqlalchemy_continuum/builder.py +++ b/sqlalchemy_continuum/builder.py @@ -3,6 +3,8 @@ from functools import wraps import sqlalchemy as sa +from sqlalchemy_continuum.expression_reflector import VersionExpressionReflector +from sqlalchemy_continuum.utils import is_table_column from sqlalchemy_utils.functions import get_declarative_base from .dialects.postgresql import create_versioning_trigger_listeners @@ -189,6 +191,7 @@ def configure_versioned_classes(self): self.build_relationships(pending_classes_copies) self.enable_active_history(pending_classes_copies) self.create_column_aliases(pending_classes_copies) + self.create_column_properties(pending_classes_copies) def enable_active_history(self, version_classes): """ @@ -220,3 +223,23 @@ def create_column_aliases(self, version_classes): continue version_class_mapper.add_property(key, sa.orm.column_property(version_class_column)) + + def create_column_properties(self, version_classes): + """ + Create equivalent column_property() on the version class (as it is on the parent model) + + This does not handle the simple column aliases - just expressions + """ + for cls in version_classes: + model_mapper = sa.inspect(cls) + version_class = self.manager.version_class_map.get(cls) + if not version_class: + continue + + version_class_mapper = sa.inspect(version_class) + reflector = VersionExpressionReflector() + for key, column in model_mapper.columns.items(): + if is_table_column(column): # We ignore simple table columns + continue + version_column = reflector(column) + version_class_mapper.add_property(key, sa.orm.column_property(version_column)) diff --git a/sqlalchemy_continuum/expression_reflector.py b/sqlalchemy_continuum/expression_reflector.py index fafb14f6..10c4fff5 100644 --- a/sqlalchemy_continuum/expression_reflector.py +++ b/sqlalchemy_continuum/expression_reflector.py @@ -5,10 +5,7 @@ class VersionExpressionReflector(sa.sql.visitors.ReplacingCloningVisitor): - def __init__(self, parent, relationship): - self.parent = parent - self.relationship = relationship - + """Take an expression and convert the columns to the version_table's columns""" def replace(self, column): if not isinstance(column, sa.Column): return @@ -18,16 +15,34 @@ def replace(self, column): reflected_column = column else: reflected_column = table.c[column.name] - if ( - column in self.relationship.local_columns and - table == self.parent.__table__ - ): - reflected_column = bindparam( - column.key, - getattr(self.parent, column.key) - ) return reflected_column def __call__(self, expr): return self.traverse(expr) + + +class RelationshipPrimaryJoinReflector(VersionExpressionReflector): + """ + Takes a relationship and modifies it to handle the primaryjoin of the relationship + """ + def __init__(self, parent, relationship): + self.parent = parent + self.relationship = relationship + + def replace(self, column): + reflected_column = super().replace(column) + if reflected_column is None: + return + + if ( + column in self.relationship.local_columns and + reflected_column.table == self.parent.__table__ + ): + # Keep the columns from the self.parent.__table__ as is + reflected_column = bindparam( + column.key, + getattr(self.parent, column.key) + ) + + return reflected_column diff --git a/sqlalchemy_continuum/relationship_builder.py b/sqlalchemy_continuum/relationship_builder.py index e4888a24..62437b88 100644 --- a/sqlalchemy_continuum/relationship_builder.py +++ b/sqlalchemy_continuum/relationship_builder.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from .exc import ClassNotVersioned -from .expression_reflector import VersionExpressionReflector +from .expression_reflector import RelationshipPrimaryJoinReflector from .operation import Operation from .table_builder import TableBuilder from .utils import adapt_columns, version_class, option @@ -46,7 +46,7 @@ def one_to_many_subquery(self, obj): def many_to_one_subquery(self, obj): tx_column = option(obj, 'transaction_column_name') - reflector = VersionExpressionReflector(obj, self.property) + reflector = RelationshipPrimaryJoinReflector(obj, self.property) return getattr(self.remote_cls, tx_column) == ( sa.select( @@ -93,7 +93,7 @@ def criteria(self, obj): elif direction.name == 'MANYTOONE': return self.many_to_one_criteria(obj) else: - reflector = VersionExpressionReflector(obj, self.property) + reflector = RelationshipPrimaryJoinReflector(obj, self.property) return reflector(self.property.primaryjoin) def many_to_many_criteria(self, obj): @@ -171,7 +171,7 @@ def many_to_one_criteria(self, obj): AND operation_type != 2 """ - reflector = VersionExpressionReflector(obj, self.property) + reflector = RelationshipPrimaryJoinReflector(obj, self.property) return sa.and_( reflector(self.property.primaryjoin), self.many_to_one_subquery(obj), @@ -209,7 +209,7 @@ def one_to_many_criteria(self, obj): ) """ - reflector = VersionExpressionReflector(obj, self.property) + reflector = RelationshipPrimaryJoinReflector(obj, self.property) return sa.and_( reflector(self.property.primaryjoin), self.one_to_many_subquery(obj), @@ -263,7 +263,7 @@ def association_subquery(self, obj): tx_column = option(obj, 'transaction_column_name') join_column = self.property.primaryjoin.right.name object_join_column = self.property.primaryjoin.left.name - reflector = VersionExpressionReflector(obj, self.property) + reflector = RelationshipPrimaryJoinReflector(obj, self.property) association_table_alias = self.association_version_table.alias() association_cols = [ diff --git a/sqlalchemy_continuum/version.py b/sqlalchemy_continuum/version.py index d71e745d..786055f8 100644 --- a/sqlalchemy_continuum/version.py +++ b/sqlalchemy_continuum/version.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from .reverter import Reverter -from .utils import get_versioning_manager, is_internal_column, parent_class +from .utils import get_versioning_manager, is_internal_column, is_table_column, parent_class class VersionClassBase(object): @@ -52,8 +52,9 @@ def changeset(self): previous_version = self.previous data = {} - for key in sa.inspect(self.__class__).columns.keys(): - if is_internal_column(self, key): + for key, column in sa.inspect(self.__class__).columns.items(): + if is_internal_column(self, key) or not is_table_column(column): + # Ignore internal columns and column_property() which are expressions continue if not previous_version: old = None diff --git a/tests/builders/test_model_builder.py b/tests/builders/test_model_builder.py index 7c93aac0..a0c4dfcc 100644 --- a/tests/builders/test_model_builder.py +++ b/tests/builders/test_model_builder.py @@ -7,3 +7,16 @@ def test_builds_relationship(self): def test_parent_has_access_to_versioning_manager(self): assert self.Article.__versioning_manager__ + + + def test_column_properties(self): + article = self.Article() + article.name = u'Name' + article.content = u'Content' + article.description = u'Desc' + self.session.add(article) + self.session.commit() + + article_version = article.versions[0] + assert article.fulltext_content == article.name + article.content + article.description + assert article.fulltext_content == article_version.fulltext_content