Skip to content

Commit

Permalink
Merge pull request #34 from Opus10/deserialization
Browse files Browse the repository at this point in the history
Fix deserialization of trigger channels
  • Loading branch information
PaulGilmartin authored Apr 11, 2023
2 parents d7fa515 + 86e0507 commit 1e21e5d
Show file tree
Hide file tree
Showing 8 changed files with 324 additions and 141 deletions.
3 changes: 3 additions & 0 deletions docs/channels.rst
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ all trigger-based notifications:
Here the ``old`` and ``new`` parameters are the (unsaved) versions of what the
trigger invoking instance looked like before and after the trigger was invoked.
These objects are built by passing in the trigger notification payload through
Django's model `deserializers <https://docs.djangoproject.com/en/4.1/topics/serialization/>`__.

In this example, ``old`` would refer to the state of our ``Author`` object
pre-creation (and would hence be ``None``) and ``new`` would refer to a copy of
the newly created ``Author`` instance. This payload is inspired by the ``OLD``
Expand Down
Empty file modified manage.py
100644 → 100755
Empty file.
106 changes: 72 additions & 34 deletions pgpubsub/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from typing import Callable, Dict, Union, List

from django.apps import apps
from django.core import serializers
from django.db import models
from django.db.models import fields


registry = defaultdict(list)
Expand Down Expand Up @@ -134,36 +134,6 @@ def _deserialize_arg(cls, arg, arg_type):
return arg_type(arg)


class TriggerPayload:
def __init__(self, payload: Dict):
self._json_payload = payload
self._model = apps.get_model(
app_label=self._json_payload['app'],
model_name=self._json_payload['model'],
)
self._old_row_data = self._json_payload['old']
self._new_row_data = self._json_payload['new']

@property
def old(self):
if self._old_row_data:
return self._entity_from_json(self._model, self._old_row_data)

@property
def new(self):
if self._new_row_data:
return self._entity_from_json(self._model, self._new_row_data)

def _entity_from_json(self, model, model_payload):
data = {}
for k, v in model_payload.items():
if isinstance(model._meta.get_field(k), fields.DateTimeField):
data[k] = datetime.datetime.fromisoformat(v)
else:
data[k] = v
return model(**data)


@dataclass
class TriggerChannel(BaseChannel):

Expand All @@ -173,9 +143,77 @@ class TriggerChannel(BaseChannel):

@classmethod
def deserialize(cls, payload: Union[Dict, str]):
payload = super().deserialize(payload)
trigger_payload = TriggerPayload(payload)
return {'old': trigger_payload.old, 'new': trigger_payload.new}
payload_dict = super().deserialize(payload)
old_model_data = cls._build_model_serializer_data(
payload_dict, state='old')
new_model_data = cls._build_model_serializer_data(
payload_dict, state='new')

old_deserialized_objects = serializers.deserialize(
'json',
json.dumps(old_model_data),
ignorenonexistent=True,
)
new_deserialized_objects = serializers.deserialize(
'json',
json.dumps(new_model_data),
ignorenonexistent=True,
)

old = next(old_deserialized_objects, None)
if old is not None:
old = old.object
new = next(new_deserialized_objects, None)
if new is not None:
new = new.object
return {'old': old, 'new': new}

@classmethod
def _build_model_serializer_data(cls, payload: Dict, state: str):
"""Reformat serialized data into shape as expected
by the Django model deserializer.
"""
app = payload['app']
model_name = payload['model']
model_cls = apps.get_model(
app_label=payload['app'],
model_name=payload['model'],
)
fields = {
field.name: field for field in model_cls._meta.fields
}
db_fields = {
field.db_column: field for field in model_cls._meta.fields
}

original_state = payload[state]
new_state = {}
model_data = []
if payload[state] is not None:
for field in list(original_state):
# Triggers serialize the notification payload with
# respect to how the model fields look as columns
# in the database. We therefore need to take
# care to map xxx_id named columns to the corresponding
# xxx model field and also to account for model fields
# with alternative database column names as declared
# by the db_column attribute.
value = original_state.pop(field)
if field.endswith('_id'):
field = field.rsplit('_id')[0]
if field in fields:
new_state[field] = value
elif field in db_fields:
field = db_fields[field].name
new_state[field] = value

model_data.append(
{'fields': new_state,
'id': new_state['id'],
'model': f'{app}.{model_name}',
},
)
return model_data


def locate_channel(channel):
Expand Down
24 changes: 24 additions & 0 deletions pgpubsub/tests/migrations/0002_auto_20230411_0829.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Generated by Django 3.2.12 on 2023-04-11 08:29

from django.db import migrations, models
import django.db.models.deletion


class Migration(migrations.Migration):

dependencies = [
('tests', '0001_initial'),
]

operations = [
migrations.AddField(
model_name='author',
name='alternative_name',
field=models.TextField(db_column='other', null=True),
),
migrations.AlterField(
model_name='author',
name='profile_picture',
field=models.ForeignKey(db_column='picture', null=True, on_delete=django.db.models.deletion.PROTECT, to='tests.media'),
),
]
11 changes: 9 additions & 2 deletions pgpubsub/tests/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,18 @@ class Media(models.Model):


class Author(models.Model):
user = models.ForeignKey(User, on_delete=models.PROTECT, null=True)
user = models.ForeignKey(
User, on_delete=models.PROTECT, null=True)
name = models.TextField()
age = models.IntegerField(null=True)
active = models.BooleanField(default=True)
profile_picture = models.ForeignKey(Media, null=True, on_delete=models.PROTECT)
profile_picture = models.ForeignKey(
Media,
null=True,
on_delete=models.PROTECT,
db_column='picture',
)
alternative_name = models.TextField(db_column='other', null=True)


class Post(models.Model):
Expand Down
106 changes: 1 addition & 105 deletions pgpubsub/tests/test_core.py
Original file line number Diff line number Diff line change
@@ -1,123 +1,19 @@
from dataclasses import dataclass
import datetime
import json
from typing import Dict, List, Set, Tuple

from django.db.transaction import atomic
import pytest

from pgpubsub.channel import Channel, TriggerChannel
from pgpubsub.listen import listen_to_channels, process_notifications, listen
from pgpubsub.models import Notification
from pgpubsub.notify import process_stored_notifications
from pgpubsub.tests.channels import (
AuthorTriggerChannel,
MediaTriggerChannel,
PostTriggerChannel,
)
from pgpubsub.tests.listeners import post_reads_per_date_cache, scan_media
from pgpubsub.tests.listeners import post_reads_per_date_cache
from pgpubsub.tests.models import Author, Media, Post


def test_deserialize_1():
@dataclass
class MyChannel(Channel):
arg1: str
arg2: Dict[int, int]
default_arg1: float = 0.0

deserialized = _deserialize(MyChannel, arg1='1', arg2={1: 2}, default_arg1=3.4)
assert {'arg1': '1', 'arg2': {1: 2},
'default_arg1': 3.4} == deserialized


def test_deserialize_2():
@dataclass
class MyChannel(Channel):
arg1: Dict[str, bool]
default_arg1: bool = False
default_arg2: int = 0

deserialized = _deserialize(
MyChannel, arg1={'Paul': False}, default_arg1=True)
assert {'arg1': {'Paul': False},
'default_arg1': True,
'default_arg2': 0} == deserialized


def test_deserialize_3():
@dataclass
class MyChannel(Channel):
arg1: datetime.date
arg2: Dict[datetime.date, bool]
arg3: Dict[str, datetime.datetime]

deserialized = _deserialize(
MyChannel,
arg1=datetime.date(2021, 1, 1),
arg2={
datetime.date(2021, 1, 7): True,
datetime.date(2021, 1, 17): False,
},
arg3={'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)},
)

assert {
'arg1': datetime.date(2021, 1, 1),
'arg2': {datetime.date(2021, 1, 7): True, datetime.date(2021, 1, 17): False},
'arg3': {'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)},
} == deserialized


def test_deserialize_trigger_channel():
@dataclass
class MyChannel(TriggerChannel):
model: Post

some_datetime = datetime.datetime.utcnow()
post = Post(content='some-content', date=some_datetime)
deserialized = MyChannel.deserialize(
json.dumps(
{
'app': 'tests',
'model': 'Post',
'old': None,
'new': {'content': 'some-content', 'date': some_datetime.isoformat()},
}
)
)
assert deserialized['new'].date == some_datetime
assert deserialized['new'].content == post.content
assert deserialized['new'].id == post.id
assert deserialized['new'].rating == post.rating
assert deserialized['new'].author == post.author


def _deserialize(channel_cls, **kwargs):
serialized = channel_cls(**kwargs).serialize()
return channel_cls.deserialize(serialized)


def test_deserialize_4():
@dataclass
class MyChannel(Channel):
arg1: List[datetime.date]
arg2: Set[float]
arg3: Tuple[str]

deserialized = _deserialize(
MyChannel,
arg1=[datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)],
arg2={1.0, 2.1},
arg3=('hello', 'world'),
)
assert {
'arg1': [datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)],
'arg2': {1.0, 2.1},
'arg3': ('hello', 'world'),
} == deserialized


@pytest.fixture()
def pg_connection():
return listen_to_channels()
Expand Down
80 changes: 80 additions & 0 deletions pgpubsub/tests/test_deserialize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from dataclasses import dataclass
import datetime
from typing import Dict, List, Set, Tuple

from pgpubsub.channel import Channel


def test_deserialize_1():
@dataclass
class MyChannel(Channel):
arg1: str
arg2: Dict[int, int]
default_arg1: float = 0.0

deserialized = _deserialize(MyChannel, arg1='1', arg2={1: 2}, default_arg1=3.4)
assert {'arg1': '1', 'arg2': {1: 2},
'default_arg1': 3.4} == deserialized


def test_deserialize_2():
@dataclass
class MyChannel(Channel):
arg1: Dict[str, bool]
default_arg1: bool = False
default_arg2: int = 0

deserialized = _deserialize(
MyChannel, arg1={'Paul': False}, default_arg1=True)
assert {'arg1': {'Paul': False},
'default_arg1': True,
'default_arg2': 0} == deserialized


def test_deserialize_3():
@dataclass
class MyChannel(Channel):
arg1: datetime.date
arg2: Dict[datetime.date, bool]
arg3: Dict[str, datetime.datetime]

deserialized = _deserialize(
MyChannel,
arg1=datetime.date(2021, 1, 1),
arg2={
datetime.date(2021, 1, 7): True,
datetime.date(2021, 1, 17): False,
},
arg3={'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)},
)

assert {
'arg1': datetime.date(2021, 1, 1),
'arg2': {datetime.date(2021, 1, 7): True, datetime.date(2021, 1, 17): False},
'arg3': {'chosen_date': datetime.datetime(2021, 1, 1, 9, 30)},
} == deserialized


def test_deserialize_4():
@dataclass
class MyChannel(Channel):
arg1: List[datetime.date]
arg2: Set[float]
arg3: Tuple[str]

deserialized = _deserialize(
MyChannel,
arg1=[datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)],
arg2={1.0, 2.1},
arg3=('hello', 'world'),
)
assert {
'arg1': [datetime.date(2021, 1, 1), datetime.date(2021, 1, 2)],
'arg2': {1.0, 2.1},
'arg3': ('hello', 'world'),
} == deserialized


def _deserialize(channel_cls, **kwargs):
serialized = channel_cls(**kwargs).serialize()
return channel_cls.deserialize(serialized)
Loading

0 comments on commit 1e21e5d

Please sign in to comment.