diff --git a/sqlalchemy_history/operation.py b/sqlalchemy_history/operation.py index fea54fd..cb328a8 100644 --- a/sqlalchemy_history/operation.py +++ b/sqlalchemy_history/operation.py @@ -90,7 +90,11 @@ def add_update(self, target): del state_copy[rel_key] if state_copy: - self.add(Operation(target, Operation.UPDATE)) - + if target in self: + # If already in current transaction and some event hook did a update + # prior to commit hook, continue with operation type as it is + self.add(Operation(target, self[self.format_key(target)].type)) + else: + self.add(Operation(target, Operation.UPDATE)) def add_delete(self, target): self.add(Operation(target, Operation.DELETE)) diff --git a/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py b/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py new file mode 100644 index 0000000..74395ce --- /dev/null +++ b/tests/reported_bugs/test_bug_141_after_flush_postexec_op_type_issue.py @@ -0,0 +1,43 @@ +import sqlalchemy as sa +from copy import copy + +from tests import TestCase +from sqlalchemy_history import version_class + + +class TestBug141(TestCase): + # ref: https://github.com/corridor/sqlalchemy-history/issues/141 + def create_models(self): + class Author(self.Model): + __tablename__ = "author" + __versioned__ = copy(self.options) + + id = sa.Column( + sa.Integer, sa.Sequence(f"{__tablename__}_seq", start=1), autoincrement=True, primary_key=True + ) + name = sa.Column(sa.Unicode(255)) + + self.Author = Author + + def test_add_record(self): + author = self.Author(name="Author 1") + @sa.event.listens_for(self.session, 'after_flush_postexec') + def after_flush_postexec(session, flush_context): + if author.name != "yoyoyoyoyo": + author.name = "yoyoyoyoyo" + self.session.add(author) + self.session.commit() + + versioned_objs = self.session.query(version_class(self.Author)).all() + assert len(versioned_objs) == 1 + assert versioned_objs[0].operation_type == 0 + assert versioned_objs[0].name == "yoyoyoyoyo" + author.name = "sdfeoinfe" + self.session.add(author) + self.session.commit() + versioned_objs = self.session.query(version_class(self.Author)).all() + assert len(versioned_objs) == 2 + assert versioned_objs[0].operation_type == 0 + assert versioned_objs[1].operation_type == 1 + assert versioned_objs[0].name == versioned_objs[1].name == "yoyoyoyoyo" + sa.event.remove(self.session, "after_flush_postexec", after_flush_postexec) diff --git a/tests/test_exotic_operation_combos.py b/tests/test_exotic_operation_combos.py index 11f388b..d2fab17 100644 --- a/tests/test_exotic_operation_combos.py +++ b/tests/test_exotic_operation_combos.py @@ -1,6 +1,3 @@ -import os -from pytest import mark - from sqlalchemy_history.operation import Operation from tests import TestCase, create_test_cases @@ -40,10 +37,6 @@ def test_insert_deleted_and_flushed_object(self): assert article2.versions[0].operation_type == Operation.INSERT assert article2.versions[1].operation_type == Operation.UPDATE - # Ref for mssql: https://github.com/sqlalchemy/sqlalchemy/discussions/8829 - @mark.skipif( - os.environ.get("DB") == "mssql", reason="mssql does not support changing the IDENTITY column" - ) def test_replace_deleted_object_with_update(self): article = self.Article() article.name = "Some article" @@ -58,7 +51,7 @@ def test_replace_deleted_object_with_update(self): self.session.delete(article) self.session.flush() - article2.id = article.id + article2.name = article.name self.session.commit() assert article2.versions.count() == 2 assert article2.versions[0].operation_type == Operation.INSERT