From 40f4945d67be080a6e473b9794176db8808c0ab5 Mon Sep 17 00:00:00 2001 From: Grigi Date: Mon, 22 Oct 2018 08:32:14 +0200 Subject: [PATCH] Refactor ``Tortoise.init()`` and test runner Does not re-create connections per test, so now tests pass when using an SQLite in-memory database Can pass event loop to test initializer function --- CHANGELOG.rst | 5 ++ Makefile | 2 +- tortoise/__init__.py | 17 +++--- tortoise/backends/asyncpg/client.py | 2 +- tortoise/backends/mysql/client.py | 2 +- tortoise/contrib/test/__init__.py | 84 ++++++++++++++++++++--------- tortoise/tests/test_fields.py | 10 ++-- tortoise/tests/test_init.py | 10 ++-- tortoise/transactions.py | 8 +-- 9 files changed, 89 insertions(+), 51 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 6daf46efb..88027c7e4 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -1,6 +1,11 @@ Changelog ========= +0.10.10 +------- +- Refactor ``Tortoise.init()`` and test runner to not re-create connections per test, so now tests pass when using an SQLite in-memory database +- Can pass event loop to test initializer function: ``initializer(loop=loop)`` + 0.10.9 ------ - Uses macros on SQLite driver to minimise syncronisation. ``aiosqlite>=0.7.0`` diff --git a/Makefile b/Makefile index 0d06bcee2..9c9116429 100644 --- a/Makefile +++ b/Makefile @@ -43,7 +43,7 @@ test: deps testall: deps coverage erase - TORTOISE_TEST_DB=sqlite:///tmp/test-\{\}.sqlite coverage run -p --concurrency=multiprocessing `which green` + TORTOISE_TEST_DB=sqlite://:memory: coverage run -p --concurrency=multiprocessing `which green` TORTOISE_TEST_DB=postgres://postgres:@127.0.0.1:5432/test_\{\} coverage run -p --concurrency=multiprocessing `which green` TORTOISE_TEST_DB="mysql://root:@127.0.0.1:3306/test_\{\}" coverage run -p --concurrency=multiprocessing `which green` coverage combine diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 6dc7df40f..e2569d2c2 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -188,6 +188,10 @@ def _init_apps(cls, apps_config): cls.apps[name] = models_map + cls._init_relations() + + cls._build_initial_querysets() + @classmethod def _get_config_from_config_file(cls, config_file): _, extension = os.path.splitext(config_file) @@ -280,7 +284,8 @@ async def init( For any configuration error """ if cls._inited: - await cls._reset_connections() + await cls.close_connections() + await cls._reset_apps() if int(bool(config) + bool(config_file) + bool(db_url)) != 1: raise ConfigurationError( 'You should init either from "config", "config_file" or "db_url"') @@ -307,10 +312,6 @@ async def init( cls._init_apps(apps_config) - cls._init_relations() - - cls._build_initial_querysets() - cls._inited = True @classmethod @@ -320,9 +321,7 @@ async def close_connections(cls): cls._connections = {} @classmethod - async def _reset_connections(cls): - await cls.close_connections() - + async def _reset_apps(cls): for app in cls.apps.values(): for model in app.values(): model._meta.default_connection = None @@ -353,7 +352,7 @@ async def _drop_databases(cls) -> None: await connection.close() await connection.db_delete() cls._connections = {} - await cls._reset_connections() + await cls._reset_apps() def run_async(coro): diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index 44db99916..c86dd6f2f 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -1,6 +1,6 @@ import logging from functools import wraps -from typing import List, SupportsInt, Optional # noqa +from typing import List, Optional, SupportsInt # noqa import asyncpg from pypika import PostgreSQLQuery diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index 37af43cba..573c93e4e 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -1,6 +1,6 @@ import logging from functools import wraps -from typing import List, SupportsInt, Optional # noqa +from typing import List, Optional, SupportsInt # noqa import aiomysql import pymysql diff --git a/tortoise/contrib/test/__init__.py b/tortoise/contrib/test/__init__.py index 5375147f2..248731e43 100644 --- a/tortoise/contrib/test/__init__.py +++ b/tortoise/contrib/test/__init__.py @@ -1,20 +1,21 @@ -import asyncio as _asyncio +import asyncio import os as _os -from copy import deepcopy -from typing import List +from asyncio.selector_events import BaseSelectorEventLoop +from typing import List, Optional from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless # noqa from asynctest import TestCase as _TestCase from asynctest import _fail_on +from asynctest.case import _Policy -from tortoise import Tortoise +from tortoise import ContextVar, Tortoise from tortoise.backends.base.config_generator import generate_config as _generate_config from tortoise.exceptions import DBConnectionError -from tortoise.transactions import start_transaction +from tortoise.transactions import current_transaction_map, start_transaction __all__ = ('SimpleTestCase', 'IsolatedTestCase', 'TestCase', 'SkipTest', 'expectedFailure', 'skip', 'skipIf', 'skipUnless', 'initializer', 'finalizer') -_TORTOISE_TEST_DB = _os.environ.get('TORTOISE_TEST_DB', 'sqlite:///tmp/test-{}.sqlite') +_TORTOISE_TEST_DB = _os.environ.get('TORTOISE_TEST_DB', 'sqlite://:memory:') expectedFailure.__doc__ = """ Mark test as expecting failiure. @@ -23,8 +24,9 @@ """ _CONFIG = {} # type: dict -_APPS = {} # type: dict _CONNECTIONS = {} # type: dict +_SELECTOR = None # type: ignore +_LOOP = None # type: BaseSelectorEventLoop def getDBConfig(app_label: str, modules: List[str]) -> dict: @@ -52,33 +54,46 @@ async def _init_db(config): await Tortoise.generate_schemas() -def initializer(): +def restore_default(): + Tortoise.apps = {} + Tortoise._connections = _CONNECTIONS.copy() + for name in Tortoise._connections.keys(): + current_transaction_map[name] = ContextVar(name, default=None) + Tortoise._init_apps(_CONFIG['apps']) + Tortoise._inited = True + + +def initializer(loop: Optional[BaseSelectorEventLoop] = None) -> None: """ Sets up the DB for testing. Must be called as part of test environment setup. """ # pylint: disable=W0603 global _CONFIG - global _APPS global _CONNECTIONS + global _SELECTOR + global _LOOP _CONFIG = getDBConfig( app_label='models', modules=['tortoise.tests.testmodels'], ) - loop = _asyncio.get_event_loop() + loop = loop or asyncio.get_event_loop() + _LOOP = loop + _SELECTOR = loop._selector # type: ignore loop.run_until_complete(_init_db(_CONFIG)) - _APPS = deepcopy(Tortoise.apps) _CONNECTIONS = Tortoise._connections.copy() - loop.run_until_complete(Tortoise._reset_connections()) + Tortoise.apps = {} + Tortoise._connections = {} + Tortoise._inited = False -def finalizer(): +def finalizer() -> None: """ Cleans up the DB after testing. Must be called as part of the test environment teardown. """ - Tortoise.apps = deepcopy(_APPS) - Tortoise._connections = _CONNECTIONS.copy() - loop = _asyncio.get_event_loop() + restore_default() + loop = _LOOP + loop._selector = _SELECTOR loop.run_until_complete(Tortoise._drop_databases()) @@ -92,6 +107,21 @@ class SimpleTestCase(_TestCase): Based on `asynctest `_ """ + use_default_loop = True + + def __init_loop(self): + if self.use_default_loop: + self.loop = _LOOP + loop = None + else: + loop = self.loop = asyncio.new_event_loop() + + policy = _Policy(asyncio.get_event_loop_policy(), + loop, self.forbid_get_event_loop) + + asyncio.set_event_loop_policy(policy) + + self.loop = self._patch_loop(self.loop) async def _setUpDB(self): pass @@ -109,7 +139,7 @@ def _setUp(self) -> None: self._checker.before_test(self) self.loop.run_until_complete(self._setUpDB()) - if _asyncio.iscoroutinefunction(self.setUp): + if asyncio.iscoroutinefunction(self.setUp): self.loop.run_until_complete(self.setUp()) else: self.setUp() @@ -118,11 +148,15 @@ def _setUp(self) -> None: self.loop._asynctest_ran = False def _tearDown(self) -> None: - self.loop.run_until_complete(self._tearDownDB()) - if _asyncio.iscoroutinefunction(self.tearDown): + if asyncio.iscoroutinefunction(self.tearDown): self.loop.run_until_complete(self.tearDown()) else: self.tearDown() + self.loop.run_until_complete(self._tearDownDB()) + Tortoise.apps = {} + Tortoise._connections = {} + Tortoise._inited = False + current_transaction_map.clear() # post-test checks self._checker.check_test(self) @@ -148,8 +182,12 @@ async def _setUpDB(self): ) await Tortoise.init(config, _create_db=True) await Tortoise.generate_schemas() + self._connections = Tortoise._connections.copy() async def _tearDownDB(self) -> None: + Tortoise._connections = self._connections.copy() + for name in Tortoise._connections.keys(): + current_transaction_map[name] = ContextVar(name, default=None) await Tortoise._drop_databases() @@ -160,13 +198,9 @@ class TestCase(SimpleTestCase): """ async def _setUpDB(self): - Tortoise.apps = deepcopy(_APPS) - Tortoise._connections = _CONNECTIONS.copy() - await Tortoise.init(_CONFIG) - + restore_default() self.transaction = await start_transaction() # pylint: disable=W0201 async def _tearDownDB(self) -> None: + restore_default() await self.transaction.rollback() - # Have to reset connections because tests are run in different loops - await Tortoise._reset_connections() diff --git a/tortoise/tests/test_fields.py b/tortoise/tests/test_fields.py index 5fa3d8c26..94133e947 100644 --- a/tortoise/tests/test_fields.py +++ b/tortoise/tests/test_fields.py @@ -210,18 +210,18 @@ async def test_create(self): obj = await testmodels.DatetimeFields.get(id=obj0.id) self.assertEqual(obj.datetime, now) self.assertEqual(obj.datetime_null, None) - self.assertLess(obj.datetime_auto - now, timedelta(seconds=1)) - self.assertLess(obj.datetime_add - now, timedelta(seconds=1)) + self.assertLess(obj.datetime_auto - now, timedelta(microseconds=10000)) + self.assertLess(obj.datetime_add - now, timedelta(microseconds=10000)) datetime_auto = obj.datetime_auto - sleep(1) + sleep(0.011) await obj.save() obj2 = await testmodels.DatetimeFields.get(id=obj.id) self.assertEqual(obj2.datetime, now) self.assertEqual(obj2.datetime_null, None) self.assertEqual(obj2.datetime_auto, obj.datetime_auto) self.assertNotEqual(obj2.datetime_auto, datetime_auto) - self.assertGreater(obj2.datetime_auto - now, timedelta(seconds=1)) - self.assertLess(obj2.datetime_auto - now, timedelta(seconds=2)) + self.assertGreater(obj2.datetime_auto - now, timedelta(microseconds=10000)) + self.assertLess(obj2.datetime_auto - now, timedelta(microseconds=20000)) self.assertEqual(obj2.datetime_add, obj.datetime_add) async def test_cast(self): diff --git a/tortoise/tests/test_init.py b/tortoise/tests/test_init.py index d2647f12f..9d4426a0b 100644 --- a/tortoise/tests/test_init.py +++ b/tortoise/tests/test_init.py @@ -10,7 +10,9 @@ class TestInitErrors(test.SimpleTestCase): async def setUp(self): try: - await Tortoise._reset_connections() + Tortoise.apps = {} + Tortoise._connections = {} + Tortoise._inited = False except ConfigurationError: pass Tortoise._inited = False @@ -20,10 +22,8 @@ async def setUp(self): ) async def tearDown(self): - try: - await Tortoise._reset_connections() - except ConfigurationError: - pass + await Tortoise.close_connections() + await Tortoise._reset_apps() async def test_basic_init(self): await Tortoise.init({ diff --git a/tortoise/transactions.py b/tortoise/transactions.py index 728a45b6c..a5e4e04c5 100644 --- a/tortoise/transactions.py +++ b/tortoise/transactions.py @@ -1,5 +1,5 @@ from functools import wraps -from typing import Callable, Optional, Dict # noqa +from typing import Callable, Dict, Optional # noqa from tortoise.backends.base.client import BaseDBAsyncClient, BaseTransactionWrapper from tortoise.exceptions import ParamsError @@ -15,7 +15,8 @@ def _get_connection(connection_name: Optional[str]) -> BaseDBAsyncClient: connection = list(Tortoise._connections.values())[0] else: raise ParamsError( - 'You are running with multiple databases, so you should specify connection_name' + 'You are running with multiple databases, so you should specify connection_name: {}' + .format(list(Tortoise._connections.keys())) ) return connection @@ -31,8 +32,7 @@ def in_transaction(connection_name: Optional[str] = None) -> BaseTransactionWrap one db connection """ connection = _get_connection(connection_name) - single_connection = connection._in_transaction() - return single_connection + return connection._in_transaction() def atomic(connection_name: Optional[str] = None) -> Callable: