diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index 462bb49be..17233b21a 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -2,7 +2,9 @@ from __future__ import with_statement +from contextlib import contextmanager import os +import sys import warnings import pytest @@ -13,7 +15,8 @@ from .django_compat import is_django_unittest from .lazy_django import get_django_version, skip_if_no_django -__all__ = ['_django_db_setup', 'db', 'transactional_db', 'admin_user', +__all__ = ['_django_db_setup', 'db', 'transactional_db', 'shared_db_wrapper', + 'admin_user', 'django_user_model', 'django_username_field', 'client', 'admin_client', 'rf', 'settings', 'live_server', '_live_server_helper'] @@ -195,6 +198,62 @@ def transactional_db(request, _django_db_setup, _django_cursor_wrapper): return _django_db_fixture_helper(True, request, _django_cursor_wrapper) +@pytest.fixture(scope='session') +def shared_db_wrapper(_django_db_setup, _django_cursor_wrapper): + """Wrapper for common database initialization code. + + This fixture provides a context manager that let's you access the database + from a transaction spanning multiple tests. + """ + from django.db import connection, transaction + + if get_django_version() < (1, 6): + raise Exception('shared_db_wrapper is only supported on Django >= 1.6.') + + class DummyException(Exception): + """Dummy for use with Atomic.__exit__.""" + + @contextmanager + def wrapper(request): + # We need to take the request + # to bind finalization to the place where this is used + if 'transactional_db' in request.funcargnames: + raise Exception( + 'shared_db_wrapper cannot be used with `transactional_db`.') + + with _django_cursor_wrapper: + if not connection.features.supports_transactions: + raise Exception( + "shared_db_wrapper cannot be used when " + "the database doesn't support transactions.") + + exc_type, exc_value, traceback = DummyException, DummyException(), None + # Use atomic instead of calling .savepoint* directly. + # This way works for both top-level transactions and "subtransactions". + atomic = transaction.atomic() + + def finalize(): + # Only run __exit__ if there was no error running the wrapped function. + # Otherwise we've run it already. + if exc_type == DummyException: + # dummy exception makes `atomic` rollback the savepoint + atomic.__exit__(exc_type, exc_value, traceback) + + try: + _django_cursor_wrapper.enable() + atomic.__enter__() + yield + except: + exc_type, exc_value, traceback = sys.exc_info() + atomic.__exit__(exc_type, exc_value, traceback) + raise + finally: + request.addfinalizer(finalize) + _django_cursor_wrapper.restore() + + return wrapper + + @pytest.fixture() def client(): """A Django test client instance.""" diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 2ac672433..d19da43bc 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -17,14 +17,14 @@ from .django_compat import is_django_unittest from .fixtures import (_django_db_setup, _live_server_helper, admin_client, admin_user, client, db, django_user_model, - django_username_field, live_server, rf, settings, - transactional_db) + django_username_field, live_server, rf, shared_db_wrapper, + settings, transactional_db) from .lazy_django import django_settings_is_configured, skip_if_no_django # Silence linters for imported fixtures. (_django_db_setup, _live_server_helper, admin_client, admin_user, client, db, django_user_model, django_username_field, live_server, rf, settings, - transactional_db) + shared_db_wrapper, transactional_db) SETTINGS_MODULE_ENV = 'DJANGO_SETTINGS_MODULE' diff --git a/tests/test_database.py b/tests/test_database.py index 0a2449cd1..b0e9ee1ff 100644 --- a/tests/test_database.py +++ b/tests/test_database.py @@ -4,6 +4,7 @@ from django.db import connection, transaction from django.test.testcases import connections_support_transactions +from pytest_django.lazy_django import get_django_version from pytest_django_test.app.models import Item @@ -51,6 +52,29 @@ def test_noaccess_fixture(noaccess): pass +@pytest.mark.skipif(get_django_version() < (1, 6), + reason="shared_db_wrapper needs at least Django 1.6") +class TestSharedDbWrapper(object): + """Tests for sharing data created with share_db_wrapper, order matters.""" + @pytest.fixture(scope='class') + def shared_item(self, request, shared_db_wrapper): + with shared_db_wrapper(request): + return Item.objects.create(name='shared item') + + def test_preparing_data(self, shared_item): + type(self)._shared_item_pk = shared_item.pk + + def test_accessing_the_same_data(self, db, shared_item): + retrieved_item = Item.objects.get(name='shared item') + assert type(self)._shared_item_pk == retrieved_item.pk + + +@pytest.mark.skipif(get_django_version() < (1, 6), + reason="shared_db_wrapper needs at least Django 1.6") +def test_shared_db_wrapper_not_leaking(db): + assert not Item.objects.filter(name='shared item').exists() + + class TestDatabaseFixtures: """Tests for the db and transactional_db fixtures"""