Skip to content

Commit

Permalink
Refactor Tortoise.init() and test runner
Browse files Browse the repository at this point in the history
Does not re-create connections per test, so now tests pass when using an SQLite in-memory database
Can pass event loop to test initializer function
  • Loading branch information
grigi committed Oct 22, 2018
1 parent d136225 commit 40f4945
Show file tree
Hide file tree
Showing 9 changed files with 89 additions and 51 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
Changelog
=========

0.10.10
-------
- Refactor ``Tortoise.init()`` and test runner to not re-create connections per test, so now tests pass when using an SQLite in-memory database
- Can pass event loop to test initializer function: ``initializer(loop=loop)``

0.10.9
------
- Uses macros on SQLite driver to minimise syncronisation. ``aiosqlite>=0.7.0``
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ test: deps

testall: deps
coverage erase
TORTOISE_TEST_DB=sqlite:///tmp/test-\{\}.sqlite coverage run -p --concurrency=multiprocessing `which green`
TORTOISE_TEST_DB=sqlite://:memory: coverage run -p --concurrency=multiprocessing `which green`
TORTOISE_TEST_DB=postgres://postgres:@127.0.0.1:5432/test_\{\} coverage run -p --concurrency=multiprocessing `which green`
TORTOISE_TEST_DB="mysql://root:@127.0.0.1:3306/test_\{\}" coverage run -p --concurrency=multiprocessing `which green`
coverage combine
Expand Down
17 changes: 8 additions & 9 deletions tortoise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,10 @@ def _init_apps(cls, apps_config):

cls.apps[name] = models_map

cls._init_relations()

cls._build_initial_querysets()

@classmethod
def _get_config_from_config_file(cls, config_file):
_, extension = os.path.splitext(config_file)
Expand Down Expand Up @@ -280,7 +284,8 @@ async def init(
For any configuration error
"""
if cls._inited:
await cls._reset_connections()
await cls.close_connections()
await cls._reset_apps()
if int(bool(config) + bool(config_file) + bool(db_url)) != 1:
raise ConfigurationError(
'You should init either from "config", "config_file" or "db_url"')
Expand All @@ -307,10 +312,6 @@ async def init(

cls._init_apps(apps_config)

cls._init_relations()

cls._build_initial_querysets()

cls._inited = True

@classmethod
Expand All @@ -320,9 +321,7 @@ async def close_connections(cls):
cls._connections = {}

@classmethod
async def _reset_connections(cls):
await cls.close_connections()

async def _reset_apps(cls):
for app in cls.apps.values():
for model in app.values():
model._meta.default_connection = None
Expand Down Expand Up @@ -353,7 +352,7 @@ async def _drop_databases(cls) -> None:
await connection.close()
await connection.db_delete()
cls._connections = {}
await cls._reset_connections()
await cls._reset_apps()


def run_async(coro):
Expand Down
2 changes: 1 addition & 1 deletion tortoise/backends/asyncpg/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import wraps
from typing import List, SupportsInt, Optional # noqa
from typing import List, Optional, SupportsInt # noqa

import asyncpg
from pypika import PostgreSQLQuery
Expand Down
2 changes: 1 addition & 1 deletion tortoise/backends/mysql/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from functools import wraps
from typing import List, SupportsInt, Optional # noqa
from typing import List, Optional, SupportsInt # noqa

import aiomysql
import pymysql
Expand Down
84 changes: 59 additions & 25 deletions tortoise/contrib/test/__init__.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
import asyncio as _asyncio
import asyncio
import os as _os
from copy import deepcopy
from typing import List
from asyncio.selector_events import BaseSelectorEventLoop
from typing import List, Optional
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless # noqa

from asynctest import TestCase as _TestCase
from asynctest import _fail_on
from asynctest.case import _Policy

from tortoise import Tortoise
from tortoise import ContextVar, Tortoise
from tortoise.backends.base.config_generator import generate_config as _generate_config
from tortoise.exceptions import DBConnectionError
from tortoise.transactions import start_transaction
from tortoise.transactions import current_transaction_map, start_transaction

__all__ = ('SimpleTestCase', 'IsolatedTestCase', 'TestCase', 'SkipTest', 'expectedFailure',
'skip', 'skipIf', 'skipUnless', 'initializer', 'finalizer')
_TORTOISE_TEST_DB = _os.environ.get('TORTOISE_TEST_DB', 'sqlite:///tmp/test-{}.sqlite')
_TORTOISE_TEST_DB = _os.environ.get('TORTOISE_TEST_DB', 'sqlite://:memory:')

expectedFailure.__doc__ = """
Mark test as expecting failiure.
Expand All @@ -23,8 +24,9 @@
"""

_CONFIG = {} # type: dict
_APPS = {} # type: dict
_CONNECTIONS = {} # type: dict
_SELECTOR = None # type: ignore
_LOOP = None # type: BaseSelectorEventLoop


def getDBConfig(app_label: str, modules: List[str]) -> dict:
Expand Down Expand Up @@ -52,33 +54,46 @@ async def _init_db(config):
await Tortoise.generate_schemas()


def initializer():
def restore_default():
Tortoise.apps = {}
Tortoise._connections = _CONNECTIONS.copy()
for name in Tortoise._connections.keys():
current_transaction_map[name] = ContextVar(name, default=None)
Tortoise._init_apps(_CONFIG['apps'])
Tortoise._inited = True


def initializer(loop: Optional[BaseSelectorEventLoop] = None) -> None:
"""
Sets up the DB for testing. Must be called as part of test environment setup.
"""
# pylint: disable=W0603
global _CONFIG
global _APPS
global _CONNECTIONS
global _SELECTOR
global _LOOP
_CONFIG = getDBConfig(
app_label='models',
modules=['tortoise.tests.testmodels'],
)

loop = _asyncio.get_event_loop()
loop = loop or asyncio.get_event_loop()
_LOOP = loop
_SELECTOR = loop._selector # type: ignore
loop.run_until_complete(_init_db(_CONFIG))
_APPS = deepcopy(Tortoise.apps)
_CONNECTIONS = Tortoise._connections.copy()
loop.run_until_complete(Tortoise._reset_connections())
Tortoise.apps = {}
Tortoise._connections = {}
Tortoise._inited = False


def finalizer():
def finalizer() -> None:
"""
Cleans up the DB after testing. Must be called as part of the test environment teardown.
"""
Tortoise.apps = deepcopy(_APPS)
Tortoise._connections = _CONNECTIONS.copy()
loop = _asyncio.get_event_loop()
restore_default()
loop = _LOOP
loop._selector = _SELECTOR
loop.run_until_complete(Tortoise._drop_databases())


Expand All @@ -92,6 +107,21 @@ class SimpleTestCase(_TestCase):
Based on `asynctest <http://asynctest.readthedocs.io/>`_
"""
use_default_loop = True

def __init_loop(self):
if self.use_default_loop:
self.loop = _LOOP
loop = None
else:
loop = self.loop = asyncio.new_event_loop()

policy = _Policy(asyncio.get_event_loop_policy(),
loop, self.forbid_get_event_loop)

asyncio.set_event_loop_policy(policy)

self.loop = self._patch_loop(self.loop)

async def _setUpDB(self):
pass
Expand All @@ -109,7 +139,7 @@ def _setUp(self) -> None:
self._checker.before_test(self)

self.loop.run_until_complete(self._setUpDB())
if _asyncio.iscoroutinefunction(self.setUp):
if asyncio.iscoroutinefunction(self.setUp):
self.loop.run_until_complete(self.setUp())
else:
self.setUp()
Expand All @@ -118,11 +148,15 @@ def _setUp(self) -> None:
self.loop._asynctest_ran = False

def _tearDown(self) -> None:
self.loop.run_until_complete(self._tearDownDB())
if _asyncio.iscoroutinefunction(self.tearDown):
if asyncio.iscoroutinefunction(self.tearDown):
self.loop.run_until_complete(self.tearDown())
else:
self.tearDown()
self.loop.run_until_complete(self._tearDownDB())
Tortoise.apps = {}
Tortoise._connections = {}
Tortoise._inited = False
current_transaction_map.clear()

# post-test checks
self._checker.check_test(self)
Expand All @@ -148,8 +182,12 @@ async def _setUpDB(self):
)
await Tortoise.init(config, _create_db=True)
await Tortoise.generate_schemas()
self._connections = Tortoise._connections.copy()

async def _tearDownDB(self) -> None:
Tortoise._connections = self._connections.copy()
for name in Tortoise._connections.keys():
current_transaction_map[name] = ContextVar(name, default=None)
await Tortoise._drop_databases()


Expand All @@ -160,13 +198,9 @@ class TestCase(SimpleTestCase):
"""

async def _setUpDB(self):
Tortoise.apps = deepcopy(_APPS)
Tortoise._connections = _CONNECTIONS.copy()
await Tortoise.init(_CONFIG)

restore_default()
self.transaction = await start_transaction() # pylint: disable=W0201

async def _tearDownDB(self) -> None:
restore_default()
await self.transaction.rollback()
# Have to reset connections because tests are run in different loops
await Tortoise._reset_connections()
10 changes: 5 additions & 5 deletions tortoise/tests/test_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,18 +210,18 @@ async def test_create(self):
obj = await testmodels.DatetimeFields.get(id=obj0.id)
self.assertEqual(obj.datetime, now)
self.assertEqual(obj.datetime_null, None)
self.assertLess(obj.datetime_auto - now, timedelta(seconds=1))
self.assertLess(obj.datetime_add - now, timedelta(seconds=1))
self.assertLess(obj.datetime_auto - now, timedelta(microseconds=10000))
self.assertLess(obj.datetime_add - now, timedelta(microseconds=10000))
datetime_auto = obj.datetime_auto
sleep(1)
sleep(0.011)
await obj.save()
obj2 = await testmodels.DatetimeFields.get(id=obj.id)
self.assertEqual(obj2.datetime, now)
self.assertEqual(obj2.datetime_null, None)
self.assertEqual(obj2.datetime_auto, obj.datetime_auto)
self.assertNotEqual(obj2.datetime_auto, datetime_auto)
self.assertGreater(obj2.datetime_auto - now, timedelta(seconds=1))
self.assertLess(obj2.datetime_auto - now, timedelta(seconds=2))
self.assertGreater(obj2.datetime_auto - now, timedelta(microseconds=10000))
self.assertLess(obj2.datetime_auto - now, timedelta(microseconds=20000))
self.assertEqual(obj2.datetime_add, obj.datetime_add)

async def test_cast(self):
Expand Down
10 changes: 5 additions & 5 deletions tortoise/tests/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
class TestInitErrors(test.SimpleTestCase):
async def setUp(self):
try:
await Tortoise._reset_connections()
Tortoise.apps = {}
Tortoise._connections = {}
Tortoise._inited = False
except ConfigurationError:
pass
Tortoise._inited = False
Expand All @@ -20,10 +22,8 @@ async def setUp(self):
)

async def tearDown(self):
try:
await Tortoise._reset_connections()
except ConfigurationError:
pass
await Tortoise.close_connections()
await Tortoise._reset_apps()

async def test_basic_init(self):
await Tortoise.init({
Expand Down
8 changes: 4 additions & 4 deletions tortoise/transactions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from functools import wraps
from typing import Callable, Optional, Dict # noqa
from typing import Callable, Dict, Optional # noqa

from tortoise.backends.base.client import BaseDBAsyncClient, BaseTransactionWrapper
from tortoise.exceptions import ParamsError
Expand All @@ -15,7 +15,8 @@ def _get_connection(connection_name: Optional[str]) -> BaseDBAsyncClient:
connection = list(Tortoise._connections.values())[0]
else:
raise ParamsError(
'You are running with multiple databases, so you should specify connection_name'
'You are running with multiple databases, so you should specify connection_name: {}'
.format(list(Tortoise._connections.keys()))
)
return connection

Expand All @@ -31,8 +32,7 @@ def in_transaction(connection_name: Optional[str] = None) -> BaseTransactionWrap
one db connection
"""
connection = _get_connection(connection_name)
single_connection = connection._in_transaction()
return single_connection
return connection._in_transaction()


def atomic(connection_name: Optional[str] = None) -> Callable:
Expand Down

0 comments on commit 40f4945

Please sign in to comment.