diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 412bdb1a2..8ae7b68c4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -9,6 +9,12 @@ Changelog 0.21 ==== +0.21.4 +------ +Fixed +^^^^^ +- Fix `update_or_create` errors when field value changed. (#1584) + 0.21.3 ------ Fixed diff --git a/tests/test_model_methods.py b/tests/test_model_methods.py index 1e0b6a9be..4df2deaf6 100644 --- a/tests/test_model_methods.py +++ b/tests/test_model_methods.py @@ -90,9 +90,9 @@ async def test_implicit_clone_pk_required_none(self): class TestModelMethods(test.TestCase): async def asyncSetUp(self): await super().asyncSetUp() - self.mdl = await Tournament.create(name="Test") - self.mdl2 = Tournament(name="Test") self.cls = Tournament + self.mdl = await self.cls.create(name="Test") + self.mdl2 = self.cls(name="Test") async def test_save(self): oldid = self.mdl.id @@ -176,6 +176,36 @@ async def test_update_or_create(self): mdl2 = await self.cls.get(name="Test2") self.assertEqual(mdl, mdl2) + async def test_update_or_create_with_defaults(self): + mdl = await self.cls.get(name=self.mdl.name) + mdl_dict = dict(mdl) + oldid = mdl.id + mdl.id = 135 + with self.assertRaisesRegex(ParamsError, "Conflict value with key='id':"): + # Missing query: check conflict with kwargs and defaults before create + await self.cls.update_or_create(id=mdl.id, defaults=mdl_dict) + desc = str(uuid4()) + # If there is no conflict with defaults and kwargs, it will be success to update or create + defaults = dict(mdl_dict, desc=desc) + kwargs = {"id": defaults["id"], "name": defaults["name"]} + mdl, created = await self.cls.update_or_create(defaults, **kwargs) + self.assertFalse(created) + self.assertEqual(defaults["desc"], mdl.desc) + self.assertNotEqual(self.mdl.desc, mdl.desc) + # Hint query: use defauts to update without checking conflict + mdl2, created = await self.cls.update_or_create( + id=oldid, desc=desc, defaults=dict(mdl_dict, desc="new desc") + ) + self.assertFalse(created) + self.assertNotEqual(dict(mdl), dict(mdl2)) + # Missing query: success to create if no conflict + not_exist_name = str(uuid4()) + no_conflict_defaults = {"name": not_exist_name, "desc": desc} + no_conflict_kwargs = {"name": not_exist_name} + mdl, created = await self.cls.update_or_create(no_conflict_defaults, **no_conflict_kwargs) + self.assertTrue(created) + self.assertEqual(not_exist_name, mdl.name) + async def test_first(self): mdl = await self.cls.first() self.assertEqual(self.mdl.id, mdl.id) diff --git a/tortoise/models.py b/tortoise/models.py index 39bf9e45d..bf1a31729 100644 --- a/tortoise/models.py +++ b/tortoise/models.py @@ -1026,6 +1026,8 @@ async def get_or_create( :param using_db: Specific DB connection to use instead of default bound :param kwargs: Query parameters. :raises IntegrityError: If create failed + :raises TransactionManagementError: If transaction error + :raises ParamsError: If defaults conflict with kwargs """ if not defaults: defaults = {} @@ -1033,15 +1035,26 @@ async def get_or_create( try: return await cls.filter(**kwargs).using_db(db).get(), False except DoesNotExist: + return await cls._create_or_get(db, defaults, **kwargs) + + @classmethod + async def _create_or_get( + cls, db: BaseDBAsyncClient, defaults: dict, **kwargs + ) -> Tuple[Self, bool]: + """Try to create, if fails with IntegrityError then try to get""" + for key in defaults.keys() & kwargs.keys(): + if (default_value := defaults[key]) != (query_value := kwargs[key]): + raise ParamsError(f"Conflict value with {key=}: {default_value=} vs {query_value=}") + merged_defaults = {**kwargs, **defaults} + try: + async with in_transaction(connection_name=db.connection_name) as connection: + return await cls.create(using_db=connection, **merged_defaults), True + except IntegrityError as exc: try: - async with in_transaction(connection_name=db.connection_name) as connection: - return await cls.create(using_db=connection, **defaults, **kwargs), True - except IntegrityError as exc: - try: - return await cls.filter(**kwargs).using_db(db).get(), False - except DoesNotExist: - pass - raise exc + return await cls.filter(**kwargs).using_db(db).get(), False + except DoesNotExist: + pass + raise exc @classmethod def select_for_update( @@ -1084,7 +1097,7 @@ async def update_or_create( if instance: await instance.update_from_dict(defaults).save(using_db=connection) return instance, False - return await cls.get_or_create(defaults, db, **kwargs) + return await cls._create_or_get(db, defaults, **kwargs) @classmethod async def create(