diff --git a/rest_framework/mixins.py b/rest_framework/mixins.py index 6ac6366c75..ff001d681b 100644 --- a/rest_framework/mixins.py +++ b/rest_framework/mixins.py @@ -69,13 +69,14 @@ def update(self, request, *args, **kwargs): serializer.is_valid(raise_exception=True) self.perform_update(serializer) - queryset = self.filter_queryset(self.get_queryset()) - if queryset._prefetch_related_lookups: + if hasattr(instance, '_prefetched_objects_cache'): # If 'prefetch_related' has been applied to a queryset, we need to - # forcibly invalidate the prefetch cache on the instance, - # and then re-prefetch related objects + # forcibly invalidate the prefetch cache on the instance instance._prefetched_objects_cache = {} - prefetch_related_objects([instance], *queryset._prefetch_related_lookups) + queryset = self.filter_queryset(self.get_queryset()) + if getattr(queryset, '_prefetch_related_lookups', None): + # And then re-prefetch related objects + prefetch_related_objects([instance], *queryset._prefetch_related_lookups) return Response(serializer.data) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index 8e7bcf4ace..a8f309f822 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -35,6 +35,13 @@ class UserUpdateWithoutPrefetchRelated(generics.UpdateAPIView): serializer_class = UserSerializer +class UserRetrieveWithoutQuerySet(generics.RetrieveUpdateAPIView): + serializer_class = UserSerializer + + def get_object(self): + return User.objects.get(pk=self.kwargs['pk']) + + class TestPrefetchRelatedUpdates(TestCase): def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') @@ -90,3 +97,11 @@ def test_db_query_count(self): ) with self.assertNumQueries(16): UserUpdateWithoutPrefetchRelated.as_view()(request, pk=self.user.pk) + + def test_can_update_without_queryset(self): + request = factory.patch('/', {'username': 'new'}) + response = UserRetrieveWithoutQuerySet.as_view()(request, pk=self.user.pk) + assert response.data['id'] == self.user.id + assert response.data['username'] == 'new' + self.user.refresh_from_db() + assert self.user.username == 'new'