From ed751180b3df486fea706d0e90781c7dc4c9e177 Mon Sep 17 00:00:00 2001 From: Jonathan Mortensen <56177725+jmo-qap@users.noreply.github.com> Date: Wed, 3 Mar 2021 03:15:39 -0800 Subject: [PATCH] support multi db atomic_requests (#7739) --- rest_framework/views.py | 8 ++++---- tests/conftest.py | 4 ++++ tests/test_atomic_requests.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 43 insertions(+), 4 deletions(-) diff --git a/rest_framework/views.py b/rest_framework/views.py index d1b5e4ed90..5b06220691 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,7 +3,7 @@ """ from django.conf import settings from django.core.exceptions import PermissionDenied -from django.db import connection, models, transaction +from django.db import connections, models from django.http import Http404 from django.http.response import HttpResponseBase from django.utils.cache import cc_delim_re, patch_vary_headers @@ -63,9 +63,9 @@ def get_view_description(view, html=False): def set_rollback(): - atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False) - if atomic_requests and connection.in_atomic_block: - transaction.set_rollback(True) + for db in connections.all(): + if db.settings_dict['ATOMIC_REQUESTS'] and db.in_atomic_block: + db.set_rollback(True) def exception_handler(exc, context): diff --git a/tests/conftest.py b/tests/conftest.py index ac29e4a429..cc32cc6373 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,6 +24,10 @@ def pytest_configure(config): 'default': { 'ENGINE': 'django.db.backends.sqlite3', 'NAME': ':memory:' + }, + 'secondary': { + 'ENGINE': 'django.db.backends.sqlite3', + 'NAME': ':memory:' } }, SITE_ID=1, diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 15b41e02f4..beda5cba19 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -130,6 +130,41 @@ def test_api_exception_rollback_transaction(self): assert BasicModel.objects.count() == 0 +@unittest.skipUnless( + connection.features.uses_savepoints, + "'atomic' requires transactions and savepoints." +) +class MultiDBTransactionAPIExceptionTests(TestCase): + databases = '__all__' + + def setUp(self): + self.view = APIExceptionView.as_view() + connections.databases['default']['ATOMIC_REQUESTS'] = True + connections.databases['secondary']['ATOMIC_REQUESTS'] = True + + def tearDown(self): + connections.databases['default']['ATOMIC_REQUESTS'] = False + connections.databases['secondary']['ATOMIC_REQUESTS'] = False + + def test_api_exception_rollback_transaction(self): + """ + Transaction is rollbacked by our transaction atomic block. + """ + request = factory.post('/') + num_queries = 4 if connection.features.can_release_savepoints else 3 + with self.assertNumQueries(num_queries): + # 1 - begin savepoint + # 2 - insert + # 3 - rollback savepoint + # 4 - release savepoint + with transaction.atomic(), transaction.atomic(using='secondary'): + response = self.view(request) + assert transaction.get_rollback() + assert transaction.get_rollback(using='secondary') + assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR + assert BasicModel.objects.count() == 0 + + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints."