From 86e0507b190bc9492952a7030a4aedb7ba11a319 Mon Sep 17 00:00:00 2001 From: Paul Gilmartin Date: Tue, 11 Apr 2023 15:35:10 +0200 Subject: [PATCH] Fix deserialization of trigger channels --- docs/channels.rst | 3 + manage.py | 0 pgpubsub/channel.py | 106 +++++++++----- .../migrations/0002_auto_20230411_0829.py | 24 ++++ pgpubsub/tests/models.py | 11 +- pgpubsub/tests/test_core.py | 106 +------------- pgpubsub/tests/test_deserialize.py | 80 +++++++++++ pgpubsub/tests/test_trigger_deserialize.py | 135 ++++++++++++++++++ 8 files changed, 324 insertions(+), 141 deletions(-) mode change 100644 => 100755 manage.py create mode 100644 pgpubsub/tests/migrations/0002_auto_20230411_0829.py create mode 100644 pgpubsub/tests/test_deserialize.py create mode 100644 pgpubsub/tests/test_trigger_deserialize.py diff --git a/docs/channels.rst b/docs/channels.rst index d4412b4..a695d0a 100644 --- a/docs/channels.rst +++ b/docs/channels.rst @@ -76,6 +76,9 @@ all trigger-based notifications: Here the ``old`` and ``new`` parameters are the (unsaved) versions of what the trigger invoking instance looked like before and after the trigger was invoked. +These objects are built by passing in the trigger notification payload through +Django's model `deserializers <https://docs.djangoproject.com/en/stable/topics/serialization/>`__. + In this example, ``old`` would refer to the state of our ``Author`` object pre-creation (and would hence be ``None``) and ``new`` would refer to a copy of the newly created ``Author`` instance. 
This payload is inspired by the ``OLD`` diff --git a/manage.py b/manage.py old mode 100644 new mode 100755 diff --git a/pgpubsub/channel.py b/pgpubsub/channel.py index 059fadd..d99ae1b 100644 --- a/pgpubsub/channel.py +++ b/pgpubsub/channel.py @@ -9,8 +9,8 @@ from typing import Callable, Dict, Union, List from django.apps import apps +from django.core import serializers from django.db import models -from django.db.models import fields registry = defaultdict(list) @@ -134,36 +134,6 @@ def _deserialize_arg(cls, arg, arg_type): return arg_type(arg) -class TriggerPayload: - def __init__(self, payload: Dict): - self._json_payload = payload - self._model = apps.get_model( - app_label=self._json_payload['app'], - model_name=self._json_payload['model'], - ) - self._old_row_data = self._json_payload['old'] - self._new_row_data = self._json_payload['new'] - - @property - def old(self): - if self._old_row_data: - return self._entity_from_json(self._model, self._old_row_data) - - @property - def new(self): - if self._new_row_data: - return self._entity_from_json(self._model, self._new_row_data) - - def _entity_from_json(self, model, model_payload): - data = {} - for k, v in model_payload.items(): - if isinstance(model._meta.get_field(k), fields.DateTimeField): - data[k] = datetime.datetime.fromisoformat(v) - else: - data[k] = v - return model(**data) - - @dataclass class TriggerChannel(BaseChannel): @@ -173,9 +143,77 @@ class TriggerChannel(BaseChannel): @classmethod def deserialize(cls, payload: Union[Dict, str]): - payload = super().deserialize(payload) - trigger_payload = TriggerPayload(payload) - return {'old': trigger_payload.old, 'new': trigger_payload.new} + payload_dict = super().deserialize(payload) + old_model_data = cls._build_model_serializer_data( + payload_dict, state='old') + new_model_data = cls._build_model_serializer_data( + payload_dict, state='new') + + old_deserialized_objects = serializers.deserialize( + 'json', + json.dumps(old_model_data), + 
ignorenonexistent=True, + ) + new_deserialized_objects = serializers.deserialize( + 'json', + json.dumps(new_model_data), + ignorenonexistent=True, + ) + + old = next(old_deserialized_objects, None) + if old is not None: + old = old.object + new = next(new_deserialized_objects, None) + if new is not None: + new = new.object + return {'old': old, 'new': new} + + @classmethod + def _build_model_serializer_data(cls, payload: Dict, state: str): + """Reformat serialized data into shape as expected + by the Django model deserializer. + """ + app = payload['app'] + model_name = payload['model'] + model_cls = apps.get_model( + app_label=payload['app'], + model_name=payload['model'], + ) + fields = { + field.name: field for field in model_cls._meta.fields + } + db_fields = { + field.db_column: field for field in model_cls._meta.fields + } + + original_state = payload[state] + new_state = {} + model_data = [] + if payload[state] is not None: + for field in list(original_state): + # Triggers serialize the notification payload with + # respect to how the model fields look as columns + # in the database. We therefore need to take + # care to map xxx_id named columns to the corresponding + # xxx model field and also to account for model fields + # with alternative database column names as declared + # by the db_column attribute. 
+ value = original_state.pop(field) + if field.endswith('_id'): + field = field.rsplit('_id')[0] + if field in fields: + new_state[field] = value + elif field in db_fields: + field = db_fields[field].name + new_state[field] = value + + model_data.append( + {'fields': new_state, + 'id': new_state['id'], + 'model': f'{app}.{model_name}', + }, + ) + return model_data def locate_channel(channel): diff --git a/pgpubsub/tests/migrations/0002_auto_20230411_0829.py b/pgpubsub/tests/migrations/0002_auto_20230411_0829.py new file mode 100644 index 0000000..a1d56f2 --- /dev/null +++ b/pgpubsub/tests/migrations/0002_auto_20230411_0829.py @@ -0,0 +1,24 @@ +# Generated by Django 3.2.12 on 2023-04-11 08:29 + +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('tests', '0001_initial'), + ] + + operations = [ + migrations.AddField( + model_name='author', + name='alternative_name', + field=models.TextField(db_column='other', null=True), + ), + migrations.AlterField( + model_name='author', + name='profile_picture', + field=models.ForeignKey(db_column='picture', null=True, on_delete=django.db.models.deletion.PROTECT, to='tests.media'), + ), + ] diff --git a/pgpubsub/tests/models.py b/pgpubsub/tests/models.py index b871d25..42f285b 100644 --- a/pgpubsub/tests/models.py +++ b/pgpubsub/tests/models.py @@ -15,11 +15,18 @@ class Media(models.Model): class Author(models.Model): - user = models.ForeignKey(User, on_delete=models.PROTECT, null=True) + user = models.ForeignKey( + User, on_delete=models.PROTECT, null=True) name = models.TextField() age = models.IntegerField(null=True) active = models.BooleanField(default=True) - profile_picture = models.ForeignKey(Media, null=True, on_delete=models.PROTECT) + profile_picture = models.ForeignKey( + Media, + null=True, + on_delete=models.PROTECT, + db_column='picture', + ) + alternative_name = models.TextField(db_column='other', null=True) class 
Post(models.Model): diff --git a/pgpubsub/tests/test_core.py b/pgpubsub/tests/test_core.py index 1a780c1..45de0f4 100644 --- a/pgpubsub/tests/test_core.py +++ b/pgpubsub/tests/test_core.py @@ -1,123 +1,19 @@ -from dataclasses import dataclass import datetime -import json -from typing import Dict, List, Set, Tuple from django.db.transaction import atomic import pytest -from pgpubsub.channel import Channel, TriggerChannel from pgpubsub.listen import listen_to_channels, process_notifications, listen from pgpubsub.models import Notification from pgpubsub.notify import process_stored_notifications from pgpubsub.tests.channels import ( AuthorTriggerChannel, MediaTriggerChannel, - PostTriggerChannel, ) -from pgpubsub.tests.listeners import post_reads_per_date_cache, scan_media +from pgpubsub.tests.listeners import post_reads_per_date_cache from pgpubsub.tests.models import Author, Media, Post -def test_deserialize_1(): - @dataclass - class MyChannel(Channel): - arg1: str - arg2: Dict[int, int] - default_arg1: float = 0.0 - - deserialized = _deserialize(MyChannel, arg1='1', arg2={1: 2}, default_arg1=3.4) - assert {'arg1': '1', 'arg2': {1: 2}, - 'default_arg1': 3.4} == deserialized - - -def test_deserialize_2(): - @dataclass - class MyChannel(Channel): - arg1: Dict[str, bool] - default_arg1: bool = False - default_arg2: int = 0 - - deserialized = _deserialize( - MyChannel, arg1={'Paul': False}, default_arg1=True) - assert {'arg1': {'Paul': False}, - 'default_arg1': True, - 'default_arg2': 0} == deserialized - - -def test_deserialize_3(): - @dataclass - class MyChannel(Channel): - arg1: datetime.date - arg2: Dict[datetime.date, bool] - arg3: Dict[str, datetime.datetime] - - deserialized = _deserialize( - MyChannel, - arg1=datetime.date(2021, 1, 1), - arg2={ - datetime.date(2021, 1, 7): True, - datetime.date(2021, 1, 17): False, - }, - arg3={'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)}, - ) - - assert { - 'arg1': datetime.date(2021, 1, 1), - 'arg2': 
{datetime.date(2021, 1, 7): True, datetime.date(2021, 1, 17): False}, - 'arg3': {'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)}, - } == deserialized - - -def test_deserialize_trigger_channel(): - @dataclass - class MyChannel(TriggerChannel): - model: Post - - some_datetime = datetime.datetime.utcnow() - post = Post(content='some-content', date=some_datetime) - deserialized = MyChannel.deserialize( - json.dumps( - { - 'app': 'tests', - 'model': 'Post', - 'old': None, - 'new': {'content': 'some-content', 'date': some_datetime.isoformat()}, - } - ) - ) - assert deserialized['new'].date == some_datetime - assert deserialized['new'].content == post.content - assert deserialized['new'].id == post.id - assert deserialized['new'].rating == post.rating - assert deserialized['new'].author == post.author - - -def _deserialize(channel_cls, **kwargs): - serialized = channel_cls(**kwargs).serialize() - return channel_cls.deserialize(serialized) - - -def test_deserialize_4(): - @dataclass - class MyChannel(Channel): - arg1: List[datetime.date] - arg2: Set[float] - arg3: Tuple[str] - - deserialized = _deserialize( - MyChannel, - arg1=[datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)], - arg2={1.0, 2.1}, - arg3=('hello', 'world'), - ) - assert { - 'arg1': [datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)], - 'arg2': {1.0, 2.1}, - 'arg3': ('hello', 'world'), - } == deserialized - - @pytest.fixture() def pg_connection(): return listen_to_channels() diff --git a/pgpubsub/tests/test_deserialize.py b/pgpubsub/tests/test_deserialize.py new file mode 100644 index 0000000..4c930a3 --- /dev/null +++ b/pgpubsub/tests/test_deserialize.py @@ -0,0 +1,80 @@ +from dataclasses import dataclass +import datetime +from typing import Dict, List, Set, Tuple + +from pgpubsub.channel import Channel + + +def test_deserialize_1(): + @dataclass + class MyChannel(Channel): + arg1: str + arg2: Dict[int, int] + default_arg1: float = 0.0 + + deserialized = _deserialize(MyChannel, arg1='1', 
arg2={1: 2}, default_arg1=3.4) + assert {'arg1': '1', 'arg2': {1: 2}, + 'default_arg1': 3.4} == deserialized + + +def test_deserialize_2(): + @dataclass + class MyChannel(Channel): + arg1: Dict[str, bool] + default_arg1: bool = False + default_arg2: int = 0 + + deserialized = _deserialize( + MyChannel, arg1={'Paul': False}, default_arg1=True) + assert {'arg1': {'Paul': False}, + 'default_arg1': True, + 'default_arg2': 0} == deserialized + + +def test_deserialize_3(): + @dataclass + class MyChannel(Channel): + arg1: datetime.date + arg2: Dict[datetime.date, bool] + arg3: Dict[str, datetime.datetime] + + deserialized = _deserialize( + MyChannel, + arg1=datetime.date(2021, 1, 1), + arg2={ + datetime.date(2021, 1, 7): True, + datetime.date(2021, 1, 17): False, + }, + arg3={'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)}, + ) + + assert { + 'arg1': datetime.date(2021, 1, 1), + 'arg2': {datetime.date(2021, 1, 7): True, datetime.date(2021, 1, 17): False}, + 'arg3': {'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)}, + } == deserialized + + +def test_deserialize_4(): + @dataclass + class MyChannel(Channel): + arg1: List[datetime.date] + arg2: Set[float] + arg3: Tuple[str] + + deserialized = _deserialize( + MyChannel, + arg1=[datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)], + arg2={1.0, 2.1}, + arg3=('hello', 'world'), + ) + assert { + 'arg1': [datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)], + 'arg2': {1.0, 2.1}, + 'arg3': ('hello', 'world'), + } == deserialized + + +def _deserialize(channel_cls, **kwargs): + serialized = channel_cls(**kwargs).serialize() + return channel_cls.deserialize(serialized) diff --git a/pgpubsub/tests/test_trigger_deserialize.py b/pgpubsub/tests/test_trigger_deserialize.py new file mode 100644 index 0000000..9c8e055 --- /dev/null +++ b/pgpubsub/tests/test_trigger_deserialize.py @@ -0,0 +1,135 @@ +import datetime +import json +from dataclasses import dataclass + +import pytest +from django.contrib.auth.models import User +from 
django.utils import timezone + +from pgpubsub import TriggerChannel +from pgpubsub.models import Notification +from pgpubsub.tests.channels import ( + AuthorTriggerChannel, + MediaTriggerChannel, + PostTriggerChannel, +) +from pgpubsub.tests.models import Post, Author, Media + + +def test_deserialize_post_trigger_channel(): + @dataclass + class MyChannel(TriggerChannel): + model: Post + + some_datetime = datetime.datetime.utcnow() + post = Post(content='some-content', date=some_datetime, pk=1) + + deserialized = MyChannel.deserialize( + json.dumps( + { + 'app': 'tests', + 'model': 'Post', + 'old': None, + 'new': { + 'content': 'some-content', + 'date': some_datetime.isoformat(), + 'id': post.pk, + # See https://github.com/Opus10/django-pgpubsub/issues/29 + 'old_field': 'foo', + }, + } + ) + ) + assert deserialized['new'].date == some_datetime + assert deserialized['new'].content == post.content + assert deserialized['new'].rating == post.rating + assert deserialized['new'].author == post.author + + +@pytest.mark.django_db(transaction=True) +def test_deserialize_insert_payload(): + user = User.objects.create(username='Billy') + media = Media.objects.create( + name='avatar.jpg', + content_type='image/png', + size=15000, + ) + author = Author.objects.create( + name='Billy', + user=user, + alternative_name='Jimmy', + profile_picture=media, + ) + # Notification comes from the AuthorTriggerChannel + # and contains a serialized version of the author + # object in the payload attribute. 
+ insert_notification = Notification.from_channel( + channel=AuthorTriggerChannel).get() + deserialized = AuthorTriggerChannel.deserialize( + insert_notification.payload) + + assert deserialized['new'].name == author.name + assert deserialized['new'].alternative_name == author.alternative_name + assert deserialized['new'].id == author.pk + assert deserialized['new'].user == author.user + assert deserialized['new'].profile_picture == author.profile_picture + + +@pytest.mark.django_db(transaction=True) +def test_deserialize_edit_payload(): + media = Media.objects.create( + name='avatar.jpg', + content_type='image/png', + size=15000, + ) + assert 1 == Notification.objects.all().count() + insert_notification = Notification.from_channel( + channel=MediaTriggerChannel).last() + + deserialized = MediaTriggerChannel.deserialize( + insert_notification.payload) + + assert media.name == deserialized['new'].name + assert media.pk == deserialized['new'].id + assert media.size == deserialized['new'].size + + media.name = 'avatar_2.jpg' + media.save() + + assert 2 == Notification.objects.all().count() + edit_notification = Notification.from_channel( + channel=MediaTriggerChannel).last() + + deserialized = MediaTriggerChannel.deserialize( + edit_notification.payload) + + assert deserialized['new'].name == media.name + assert deserialized['new'].id == media.pk + assert deserialized['new'].size == media.size + + +@pytest.mark.django_db(transaction=True) +def test_deserialize_delete_payload(): + user = User.objects.create(username='Billy') + author = Author.objects.create(name='Billy', user=user) + + post = Post.objects.create( + author=author, + content='my post', + date=timezone.now(), + ) + original_id = post.pk + + # When we delete a post, a notification is sent via + # PostTriggerChannel + post.delete() + delete_notification = Notification.from_channel( + channel=PostTriggerChannel).get() + deserialized = PostTriggerChannel.deserialize( + delete_notification.payload) + + assert 
deserialized['old'].author == post.author + assert deserialized['old'].date.date() == post.date.date() + assert deserialized['old'].date.time() == post.date.time() + assert deserialized['old'].id == original_id + assert deserialized['new'] is None