From 3ae9fc17374d399578c3ef36604c3fbebf7cc03b Mon Sep 17 00:00:00 2001 From: Viicos <65306057+Viicos@users.noreply.github.com> Date: Tue, 12 Dec 2023 12:57:09 +0100 Subject: [PATCH] =?UTF-8?q?[#351]=20=E2=9C=A8=20Use=20JSON=20Merge=20Patch?= =?UTF-8?q?=20when=20doing=20a=20partial=20update=20on=20records?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/objects/api/serializers.py | 6 +++++ src/objects/api/utils.py | 21 +++++++++++++++++- src/objects/tests/test_merge_patch.py | 29 +++++++++++++++++++++++++ src/objects/tests/v1/test_object_api.py | 9 ++++++-- src/objects/tests/v2/test_object_api.py | 9 ++++++-- 5 files changed, 69 insertions(+), 5 deletions(-) create mode 100644 src/objects/tests/test_merge_patch.py diff --git a/src/objects/api/serializers.py b/src/objects/api/serializers.py index 21c00cb7..cbf908eb 100644 --- a/src/objects/api/serializers.py +++ b/src/objects/api/serializers.py @@ -10,6 +10,7 @@ from .fields import ObjectSlugRelatedField, ObjectTypeField, ObjectUrlField from .validators import GeometryValidator, IsImmutableValidator, JsonSchemaValidator +from .utils import merge_patch class ObjectRecordSerializer(serializers.ModelSerializer): @@ -129,6 +130,11 @@ def update(self, instance, validated_data): # in case of PATCH if "version" not in validated_data: validated_data["version"] = instance.version + if "data" in validated_data: + # Apply JSON Merge Patch for record data + validated_data["data"] = merge_patch( + instance.data, validated_data["data"] + ) record = super().create(validated_data) return record diff --git a/src/objects/api/utils.py b/src/objects/api/utils.py index 84c2b86f..bd337a3b 100644 --- a/src/objects/api/utils.py +++ b/src/objects/api/utils.py @@ -1,5 +1,5 @@ from datetime import date -from typing import Union +from typing import Any, Dict, Union from djchoices import DjangoChoices @@ -43,3 +43,22 @@ def display_choice_values_for_help_text(choices: DjangoChoices) -> str: items.append(item) return "\n".join(items) + + +def merge_patch(target: Any, patch: Any) -> Dict[str, Any]: + """An implementation of https://datatracker.ietf.org/doc/html/rfc7396 - JSON Merge Patch""" + if not isinstance(patch, dict): + return patch + + if not isinstance(target, dict): + # Ignore the contents and set it to an empty dict + target = {} + for k, v in patch.items(): + if v is None: + if k in target: + # remove the key/value pair from target + del target[k] + else: + target[k] = merge_patch(target.get(k), v) + + return target diff --git a/src/objects/tests/test_merge_patch.py b/src/objects/tests/test_merge_patch.py new file mode 100644 index 00000000..47e07cea --- /dev/null +++ b/src/objects/tests/test_merge_patch.py @@ -0,0 +1,29 @@ +from unittest import TestCase + +from objects.api.utils import merge_patch + + +class MergePatchTests(TestCase): + def test_merge_patch(self): + + test_data = [ + ({"a": "b"}, {"a": "c"}, {"a": "c"}), + ({"a": "b"}, {"b": "c"}, {"a": "b", "b": "c"}), + ({"a": "b"}, {"a": None}, {}), + ({"a": "b", "b": "c"}, {"a": None}, {"b": "c"}), + ({"a": ["b"]}, {"a": "c"}, {"a": "c"}), + ({"a": "c"}, {"a": ["b"]}, {"a": ["b"]}), + ({"a": {"b": "c"}}, {"a": {"b": "d", "c": None}}, {"a": {"b": "d"}}), + ({"a": [{"b": "c"}]}, {"a": [1]}, {"a": [1]}), + (["a", "b"], ["c", "d"], ["c", "d"]), + ({"a": "b"}, ["c"], ["c"]), + ({"a": "foo"}, None, None), + ({"a": "foo"}, "bar", "bar"), + ({"e": None}, {"a": 1}, {"e": None, "a": 1}), + ([1, 2], {"a": "b", "c": None}, {"a": "b"}), + ({}, {"a": {"bb": {"ccc": None}}}, {"a": {"bb": {}}}), + ] + + for target, patch, expected in test_data: + with self.subTest(): + self.assertEqual(merge_patch(target, patch), expected) diff --git a/src/objects/tests/v1/test_object_api.py b/src/objects/tests/v1/test_object_api.py index fed267a1..52667ad7 100644 --- a/src/objects/tests/v1/test_object_api.py +++ b/src/objects/tests/v1/test_object_api.py @@ -205,7 +205,10 @@ def test_patch_object_record(self, m): ) initial_record = ObjectRecordFactory.create( - version=1, object__object_type=self.object_type, start_at=date.today() + version=1, + object__object_type=self.object_type, + start_at=date.today(), + data={"name": "Name", "diameter": 20}, ) object = initial_record.object @@ -229,8 +232,10 @@ def test_patch_object_record(self, m): current_record = object.current_record self.assertEqual(current_record.version, initial_record.version) + # The actual behavior of the data merging is in test_merge_patch.py: self.assertEqual( - current_record.data, {"plantDate": "2020-04-12", "diameter": 30} + current_record.data, + {"plantDate": "2020-04-12", "diameter": 30, "name": "Name"}, ) self.assertEqual(current_record.start_at, date(2020, 1, 1)) self.assertEqual(current_record.registration_at, date(2020, 8, 8)) diff --git a/src/objects/tests/v2/test_object_api.py b/src/objects/tests/v2/test_object_api.py index 12b5a43a..b688048c 100644 --- a/src/objects/tests/v2/test_object_api.py +++ b/src/objects/tests/v2/test_object_api.py @@ -227,7 +227,10 @@ def test_patch_object_record(self, m): ) initial_record = ObjectRecordFactory.create( - version=1, object__object_type=self.object_type, start_at=date.today() + version=1, + object__object_type=self.object_type, + start_at=date.today(), + data={"name": "Name", "diameter": 20}, ) object = initial_record.object @@ -251,8 +254,10 @@ def test_patch_object_record(self, m): current_record = object.current_record self.assertEqual(current_record.version, initial_record.version) + # The actual behavior of the data merging is in test_merge_patch.py: self.assertEqual( - current_record.data, {"plantDate": "2020-04-12", "diameter": 30} + current_record.data, + {"plantDate": "2020-04-12", "diameter": 30, "name": "Name"}, ) self.assertEqual(current_record.start_at, date(2020, 1, 1)) self.assertEqual(current_record.registration_at, date(2020, 8, 8))