From af2c4a6297a08cc8f62c4bbc5f36cc7b51fefe32 Mon Sep 17 00:00:00 2001 From: Yuekui Li Date: Fri, 18 Jun 2021 22:33:32 -0700 Subject: [PATCH 1/4] Re-prefetch related objects after updating --- rest_framework/generics.py | 24 +++++++++++++++++ rest_framework/mixins.py | 6 ++++- tests/test_prefetch_related.py | 47 +++++++++++++++------------------- 3 files changed, 50 insertions(+), 27 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 55cfafda44..17290c5ebc 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,6 +1,7 @@ """ Generic views that provide commonly needed behaviour. """ +from typing import Iterable from django.core.exceptions import ValidationError from django.db.models.query import QuerySet from django.http import Http404 @@ -45,6 +46,8 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS + prefetch_related = [] + def get_queryset(self): """ Get the list of items for this view. @@ -68,9 +71,30 @@ def get_queryset(self): queryset = self.queryset if isinstance(queryset, QuerySet): + # Prefetch related objects + if self.get_prefetch_related(): + queryset = queryset.prefetch_related(*self.get_prefetch_related()) # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset + + def get_prefetch_related(self): + """ + Get the list of prefetch related objects for self.queryset or instance. + This must be an iterable. + Defaults to using `self.prefetch_related`. + + You may want to override this if you need to provide prefetched objects + depending on the incoming request. + + (Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`) + """ + assert isinstance(self.prefetch_related, Iterable), ( + "'%s' should either include an iterable `prefetch_related` attribute, " + "or override the `get_prefetch_related()` method." + % self.__class__.__name__ + ) + return self.prefetch_related def get_object(self): """ diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 7fa8947cb9..98127757fa 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -4,6 +4,8 @@ We don't bind behaviour to http method handlers yet, which allows mixin classes to be composed in interesting ways. """ +from django.db.models.query import prefetch_related_objects + from rest_framework import status from rest_framework.response import Response from rest_framework.settings import api_settings @@ -69,8 +71,10 @@ def update(self, request, *args, **kwargs): if getattr(instance, '_prefetched_objects_cache', None): # If 'prefetch_related' has been applied to a queryset, we need to - # forcibly invalidate the prefetch cache on the instance. + # forcibly invalidate the prefetch cache on the instance, + # and then re-prefetch related objects instance._prefetched_objects_cache = {} + prefetch_related_objects([instance], *self.get_prefetch_related()) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index b07087c978..2f0064bf13 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -14,8 +14,9 @@ class Meta: class UserUpdate(generics.UpdateAPIView): - queryset = User.objects.exclude(username='exclude').prefetch_related('groups') + queryset = User.objects.exclude(username='exclude') serializer_class = UserSerializer + prefetch_related = ['groups'] class TestPrefetchRelatedUpdates(TestCase): @@ -23,36 +24,30 @@ def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.user.groups.set(self.groups) - - def test_prefetch_related_updates(self): - view = UserUpdate.as_view() - pk = self.user.pk - groups_pk = self.groups[0].pk - request = factory.put('/', {'username': 'new', 'groups': [groups_pk]}, format='json') - response = view(request, pk=pk) - assert User.objects.get(pk=pk).groups.count() == 1 - expected = { - 'id': pk, + self.expected = { + 'id': self.user.pk, 'username': 'new', 'groups': [1], - 'email': 'tom@example.com' + 'email': 'tom@example.com', } - assert response.data == expected + self.view = UserUpdate.as_view() + + def test_prefetch_related_updates(self): + request = factory.put( + '/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json' + ) + response = self.view(request, pk=self.user.pk) + assert User.objects.get(pk=self.user.pk).groups.count() == 1 + assert response.data == self.expected def test_prefetch_related_excluding_instance_from_original_queryset(self): """ Regression test for https://github.com/encode/django-rest-framework/issues/4661 """ - view = UserUpdate.as_view() - pk = self.user.pk - groups_pk = self.groups[0].pk - request = factory.put('/', {'username': 'exclude', 'groups': [groups_pk]}, format='json') - response = view(request, pk=pk) - assert User.objects.get(pk=pk).groups.count() == 1 - expected = { - 'id': pk, - 'username': 'exclude', - 'groups': [1], - 'email': 'tom@example.com' - } - assert response.data == expected + request = factory.put( + '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' + ) + response = self.view(request, pk=self.user.pk) + assert User.objects.get(pk=self.user.pk).groups.count() == 1 + self.expected['username'] = 'exclude' + assert response.data == self.expected From 7f24ef2af6705d95686305938688cf659fd7a916 Mon Sep 17 00:00:00 2001 From: Yuekui Li Date: Fri, 18 Jun 2021 22:50:49 -0700 Subject: [PATCH 2/4] Fix flake8 format --- rest_framework/generics.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index 17290c5ebc..e42ca529cb 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -2,6 +2,7 @@ Generic views that provide commonly needed behaviour. """ from typing import Iterable + from django.core.exceptions import ValidationError from django.db.models.query import QuerySet from django.http import Http404 @@ -77,7 +78,7 @@ def get_queryset(self): # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset - + def get_prefetch_related(self): """ Get the list of prefetch related objects for self.queryset or instance. From 2da4374ee10a1544cad73d779e78f40b513cada1 Mon Sep 17 00:00:00 2001 From: Yuekui Li Date: Wed, 7 Jul 2021 18:53:46 -0700 Subject: [PATCH 3/4] Use _prefetch_related_lookups and refine test cases --- rest_framework/generics.py | 25 ------------------- rest_framework/mixins.py | 3 ++- tests/test_prefetch_related.py | 44 ++++++++++++++++++++++++++++------ 3 files changed, 39 insertions(+), 33 deletions(-) diff --git a/rest_framework/generics.py b/rest_framework/generics.py index e42ca529cb..55cfafda44 100644 --- a/rest_framework/generics.py +++ b/rest_framework/generics.py @@ -1,8 +1,6 @@ """ Generic views that provide commonly needed behaviour. """ -from typing import Iterable - from django.core.exceptions import ValidationError from django.db.models.query import QuerySet from django.http import Http404 @@ -47,8 +45,6 @@ class GenericAPIView(views.APIView): # The style to use for queryset pagination. pagination_class = api_settings.DEFAULT_PAGINATION_CLASS - prefetch_related = [] - def get_queryset(self): """ Get the list of items for this view. @@ -72,31 +68,10 @@ def get_queryset(self): queryset = self.queryset if isinstance(queryset, QuerySet): - # Prefetch related objects - if self.get_prefetch_related(): - queryset = queryset.prefetch_related(*self.get_prefetch_related()) # Ensure queryset is re-evaluated on each request. queryset = queryset.all() return queryset - def get_prefetch_related(self): - """ - Get the list of prefetch related objects for self.queryset or instance. - This must be an iterable. - Defaults to using `self.prefetch_related`. - - You may want to override this if you need to provide prefetched objects - depending on the incoming request. - - (Eg. `['toppings', Prefetch('restaurants', queryset=Restaurant.objects.select_related('best_pizza'))]`) - """ - assert isinstance(self.prefetch_related, Iterable), ( - "'%s' should either include an iterable `prefetch_related` attribute, " - "or override the `get_prefetch_related()` method." - % self.__class__.__name__ - ) - return self.prefetch_related - def get_object(self): """ Returns the object the view is displaying. diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 98127757fa..6031b06ad0 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -74,7 +74,8 @@ def update(self, request, *args, **kwargs): # forcibly invalidate the prefetch cache on the instance, # and then re-prefetch related objects instance._prefetched_objects_cache = {} - prefetch_related_objects([instance], *self.get_prefetch_related()) + queryset = self.filter_queryset(self.get_queryset()) + prefetch_related_objects([instance], *queryset._prefetch_related_lookups) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 2f0064bf13..d380666903 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -8,36 +8,52 @@ class UserSerializer(serializers.ModelSerializer): + permissions = serializers.SerializerMethodField() + + def get_permissions(self, obj): + ret = [] + for g in obj.groups.all(): + ret.extend([p.pk for p in g.permissions.all()]) + return ret + class Meta: model = User - fields = ('id', 'username', 'email', 'groups') + fields = ('id', 'username', 'email', 'groups', 'permissions') class UserUpdate(generics.UpdateAPIView): + queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions') + serializer_class = UserSerializer + + +class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): queryset = User.objects.exclude(username='exclude') serializer_class = UserSerializer - prefetch_related = ['groups'] class TestPrefetchRelatedUpdates(TestCase): def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') - self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] + self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] self.user.groups.set(self.groups) self.expected = { 'id': self.user.pk, - 'username': 'new', - 'groups': [1], + 'username': 'tom', + 'groups': [group.pk for group in self.groups], 'email': 'tom@example.com', + 'permissions': [], } self.view = UserUpdate.as_view() def test_prefetch_related_updates(self): + self.groups.append(Group.objects.create(name='c')) request = factory.put( - '/', {'username': 'new', 'groups': [self.groups[0].pk]}, format='json' + '/', {'username': 'new', 'groups': [group.pk for group in self.groups]}, format='json' ) + self.expected['username'] = 'new' + self.expected['groups'] = [group.pk for group in self.groups] response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 1 + assert User.objects.get(pk=self.user.pk).groups.count() == 11 assert response.data == self.expected def test_prefetch_related_excluding_instance_from_original_queryset(self): @@ -50,4 +66,18 @@ def test_prefetch_related_excluding_instance_from_original_queryset(self): response = self.view(request, pk=self.user.pk) assert User.objects.get(pk=self.user.pk).groups.count() == 1 self.expected['username'] = 'exclude' + self.expected['groups'] = [self.groups[0].pk] assert response.data == self.expected + + def test_db_query_count(self): + request = factory.put( + '/', {'username': 'new'}, format='json' + ) + with self.assertNumQueries(7): + self.view(request, pk=self.user.pk) + + request = factory.put( + '/', {'username': 'new2'}, format='json' + ) + with self.assertNumQueries(15): + UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) From f322c0415941e410b3cf1c558e91605691d1416b Mon Sep 17 00:00:00 2001 From: Yuekui Li Date: Wed, 23 Nov 2022 19:36:18 -0800 Subject: [PATCH 4/4] Add more test cases and refine prefetch checking --- rest_framework/mixins.py | 4 ++-- tests/test_prefetch_related.py | 21 +++++++++++++++------ 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6031b06ad0..6ac6366c75 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -69,12 +69,12 @@ def update(self, request, *args, **kwargs): serializer.is_valid(raise_exception=True) self.perform_update(serializer) - if getattr(instance, '_prefetched_objects_cache', None): + queryset = self.filter_queryset(self.get_queryset()) + if queryset._prefetch_related_lookups: # If 'prefetch_related' has been applied to a queryset, we need to # forcibly invalidate the prefetch cache on the instance, # and then re-prefetch related objects instance._prefetched_objects_cache = {} - queryset = self.filter_queryset(self.get_queryset()) prefetch_related_objects([instance], *queryset._prefetch_related_lookups) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index d380666903..8e7bcf4ace 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -1,4 +1,5 @@ from django.contrib.auth.models import Group, User +from django.db.models.query import Prefetch from django.test import TestCase from rest_framework import generics, serializers @@ -21,8 +22,11 @@ class Meta: fields = ('id', 'username', 'email', 'groups', 'permissions') -class UserUpdate(generics.UpdateAPIView): - queryset = User.objects.exclude(username='exclude').prefetch_related('groups__permissions') +class UserRetrieveUpdate(generics.RetrieveUpdateAPIView): + queryset = User.objects.exclude(username='exclude').prefetch_related( + Prefetch('groups', queryset=Group.objects.exclude(name='exclude')), + 'groups__permissions', + ) serializer_class = UserSerializer @@ -36,6 +40,7 @@ def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') self.groups = [Group.objects.create(name=f'group {i}') for i in range(10)] self.user.groups.set(self.groups) + self.user.groups.add(Group.objects.create(name='exclude')) self.expected = { 'id': self.user.pk, 'username': 'tom', @@ -43,7 +48,7 @@ def setUp(self): 'email': 'tom@example.com', 'permissions': [], } - self.view = UserUpdate.as_view() + self.view = UserRetrieveUpdate.as_view() def test_prefetch_related_updates(self): self.groups.append(Group.objects.create(name='c')) @@ -53,7 +58,11 @@ def test_prefetch_related_updates(self): self.expected['username'] = 'new' self.expected['groups'] = [group.pk for group in self.groups] response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 11 + assert User.objects.get(pk=self.user.pk).groups.count() == 12 + assert response.data == self.expected + # Update and fetch should get same result + request = factory.get('/') + response = self.view(request, pk=self.user.pk) assert response.data == self.expected def test_prefetch_related_excluding_instance_from_original_queryset(self): @@ -64,7 +73,7 @@ def test_prefetch_related_excluding_instance_from_original_queryset(self): '/', {'username': 'exclude', 'groups': [self.groups[0].pk]}, format='json' ) response = self.view(request, pk=self.user.pk) - assert User.objects.get(pk=self.user.pk).groups.count() == 1 + assert User.objects.get(pk=self.user.pk).groups.count() == 2 self.expected['username'] = 'exclude' self.expected['groups'] = [self.groups[0].pk] assert response.data == self.expected @@ -79,5 +88,5 @@ def test_db_query_count(self): request = factory.put( '/', {'username': 'new2'}, format='json' ) - with self.assertNumQueries(15): + with self.assertNumQueries(16): UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk)