Skip to content

Commit

Permalink
feat(#24): Allow tables in get_versioning_manager()
Browse files Browse the repository at this point in the history
Add versioning related information to Tables and Models.
Now, Tables will also have:
 - __versioning_manager__ - So we can keep track of the manager used for a table.
 - __version_parent__ - So we can keep track of the parent

Using this, we can now allow table in get_versioning_manager()
Also add unittests for the function
  • Loading branch information
AbdealiLoKo committed Nov 17, 2022
1 parent 1e7e561 commit 37e5ae3
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 9 deletions.
5 changes: 4 additions & 1 deletion sqlalchemy_history/table_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,12 +127,15 @@ def __call__(self, extends=None):
"""
Builds version table.
"""
self.parent_table.__versioning_manager__ = self.manager
columns = self.columns if extends is None else []
self.manager.plugins.after_build_version_table_columns(self, columns)
return sa.schema.Table(
version_table = sa.schema.Table(
extends.name if extends is not None else self.table_name,
self.parent_table.metadata,
*columns,
schema=self.parent_table.schema,
extend_existing=extends is not None,
)
version_table.__versioning_manager__ = self.manager
return version_table
34 changes: 26 additions & 8 deletions sqlalchemy_history/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,20 +14,38 @@
from sqlalchemy_history.exc import ClassNotVersioned


def get_versioning_manager(obj_or_class):
def get_versioning_manager(item):
"""
Return the associated SQLAlchemy-Continuum VersioningManager for given
SQLAlchemy declarative model class or object.
SQLAlchemy declarative model class or object or table.
:param obj_or_class: SQLAlchemy declarative model object or class
:param item: An item from SQLAlchemy. Can be:
- A declarative ORM object
- A declarative ORM class
- An instance of a SQL table
"""
if isinstance(obj_or_class, AliasedClass):
obj_or_class = sa.inspect(obj_or_class).mapper.class_
cls = obj_or_class if isclass(obj_or_class) else obj_or_class.__class__
# The ORM class or SQL table on which versioning was enabled
versioned_item = None
if isclass(item):
versioned_item = item
else:
if isinstance(item, AliasedClass):
versioned_item = sa.inspect(item).mapper.class_
elif isinstance(item, sa.Table):
versioned_item = item
else:
versioned_item = item.__class__

try:
return cls.__versioning_manager__
return versioned_item.__versioning_manager__
except AttributeError:
raise ClassNotVersioned(cls.__name__)
if isinstance(versioned_item, sa.Table):
name = 'Table "%s"' % versioned_item.name
else:
name = versioned_item.__name__
# NOTE: We say ClassNotVersioned - but it can also throw an error for a table.
# Maybe we want to make this exc more generic ?
raise ClassNotVersioned(name)


def option(obj_or_class, option_name):
Expand Down
73 changes: 73 additions & 0 deletions tests/utils/test_get_versioning_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from copy import copy
from pytest import raises
import sqlalchemy as sa
from sqlalchemy_history import versioning_manager
from sqlalchemy_history.exc import ClassNotVersioned
from sqlalchemy_history.utils import get_versioning_manager

from tests import TestCase


class TestGetVersioningManager(TestCase):
def create_models(self):
"""
Creates many-to-many relationship between Article and Tag
Article is versioned. But Tag is not versioned
"""

class Article(self.Model):
__tablename__ = "article"
__versioned__ = copy(self.options)

id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))

article_tag = sa.Table(
"article_tag",
self.Model.metadata,
sa.Column(
"article_id",
sa.Integer,
sa.ForeignKey("article.id", ondelete="CASCADE"),
primary_key=True,
),
sa.Column("tag_id", sa.Integer, sa.ForeignKey("tag.id", ondelete="CASCADE"), primary_key=True),
)

class Tag(self.Model):
__tablename__ = "tag"

id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
name = sa.Column(sa.Unicode(255))
articles = sa.orm.relationship(Article, secondary=article_tag, backref="tags")

self.Article = Article
self.article_tag = article_tag
self.Tag = Tag

def test_parent_class(self):
assert get_versioning_manager(self.Article) == versioning_manager

def test_parent_table(self):
assert get_versioning_manager(self.Article.__table__) == versioning_manager

def test_version_class(self):
assert get_versioning_manager(self.ArticleVersion) == versioning_manager

def test_version_table(self):
assert get_versioning_manager(self.ArticleVersion.__table__) == versioning_manager

def test_association_table(self):
assert get_versioning_manager(self.article_tag) == versioning_manager

def test_aliased_class(self):
assert get_versioning_manager(sa.orm.aliased(self.Article)) == versioning_manager
assert get_versioning_manager(sa.orm.aliased(self.ArticleVersion)) == versioning_manager

def test_unknown_class(self):
with raises(ClassNotVersioned):
get_versioning_manager(self.Tag)

def test_unknown_table(self):
with raises(ClassNotVersioned):
get_versioning_manager(self.Tag.__table__)

0 comments on commit 37e5ae3

Please sign in to comment.