Skip to content

Commit

Permalink
Support atomic transaction views in multiple database connections
Browse files Browse the repository at this point in the history
  • Loading branch information
aradbar committed Jan 17, 2023
1 parent 0618fa8 commit ac3e849
Showing 1 changed file with 80 additions and 21 deletions.
101 changes: 80 additions & 21 deletions tests/test_atomic_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,35 @@


class BasicView(APIView):
database = 'default'

def get_queryset(self):
return BasicModel.objects.using(self.database).all()

def post(self, request, *args, **kwargs):
BasicModel.objects.create()
self.get_queryset().create()
return Response({'method': 'GET'})


class ErrorView(APIView):
class ErrorView(BasicView):
def post(self, request, *args, **kwargs):
BasicModel.objects.create()
self.get_queryset().create()
raise Exception


class APIExceptionView(APIView):
class APIExceptionView(BasicView):
def post(self, request, *args, **kwargs):
BasicModel.objects.create()
self.get_queryset().create()
raise APIException


class NonAtomicAPIExceptionView(APIView):
class NonAtomicAPIExceptionView(BasicView):
@transaction.non_atomic_requests
def dispatch(self, *args, **kwargs):
return super().dispatch(*args, **kwargs)

def get(self, request, *args, **kwargs):
BasicModel.objects.all()
self.get_queryset()
raise Http404


Expand All @@ -53,34 +58,52 @@ def get(self, request, *args, **kwargs):
"'atomic' requires transactions and savepoints."
)
class DBTransactionTests(TestCase):
databases = '__all__'

def setUp(self):
self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
self.view = BasicView
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_no_exception_commit_transaction(self):
request = factory.post('/')

with self.assertNumQueries(1):
response = self.view(request)
response = self.view.as_view()(request)
assert not transaction.get_rollback()
assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.count() == 1

def test_no_exception_commit_transaction_spare_connection(self):
request = factory.post('/')

with self.assertNumQueries(1, using='spare'):
view = self.view.as_view(database='spare')
response = view(request)
assert not transaction.get_rollback(using='spare')
assert response.status_code == status.HTTP_200_OK
assert BasicModel.objects.using('spare').count() == 1


@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionErrorTests(TestCase):
databases = '__all__'

def setUp(self):
self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
self.view = ErrorView
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_generic_exception_delegate_transaction_management(self):
"""
Expand All @@ -95,22 +118,37 @@ def test_generic_exception_delegate_transaction_management(self):
# 2 - insert
# 3 - release savepoint
with transaction.atomic():
self.assertRaises(Exception, self.view, request)
self.assertRaises(Exception, self.view.as_view(), request)
assert not transaction.get_rollback()
assert BasicModel.objects.count() == 1

def test_generic_exception_delegate_transaction_management_spare_connections(self):
request = factory.post('/')
with self.assertNumQueries(3, using='spare'):
# 1 - begin savepoint
# 2 - insert
# 3 - release savepoint
with transaction.atomic(using='spare'):
self.assertRaises(Exception, self.view.as_view(database='spare'), request)
assert not transaction.get_rollback(using='spare')
assert BasicModel.objects.using('spare').count() == 1


@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
)
class DBTransactionAPIExceptionTests(TestCase):
databases = '__all__'

def setUp(self):
self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
self.view = APIExceptionView
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_api_exception_rollback_transaction(self):
"""
Expand All @@ -124,11 +162,28 @@ def test_api_exception_rollback_transaction(self):
# 3 - rollback savepoint
# 4 - release savepoint
with transaction.atomic():
response = self.view(request)
response = self.view.as_view()(request)
assert transaction.get_rollback()
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.count() == 0

def test_api_exception_rollback_transaction_spare_connection(self):
"""
Transaction is rollbacked by our transaction atomic block.
"""
request = factory.post('/')
num_queries = 4 if connections['spare'].features.can_release_savepoints else 3
with self.assertNumQueries(num_queries, using='spare'):
# 1 - begin savepoint
# 2 - insert
# 3 - rollback savepoint
# 4 - release savepoint
with transaction.atomic(using='spare'):
response = self.view.as_view(database='spare')(request)
assert transaction.get_rollback(using='spare')
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
assert BasicModel.objects.using('spare').count() == 0


@unittest.skipUnless(
connection.features.uses_savepoints,
Expand Down Expand Up @@ -171,11 +226,15 @@ def test_api_exception_rollback_transaction(self):
)
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
databases = '__all__'

def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = True

def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
for database in connections.databases:
connections.databases[database]['ATOMIC_REQUESTS'] = False

def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/')
Expand Down

0 comments on commit ac3e849

Please sign in to comment.