diff --git a/tests/test_concurrency.py b/tests/test_concurrency.py index ca1e5ab0d..cc565b677 100644 --- a/tests/test_concurrency.py +++ b/tests/test_concurrency.py @@ -2,6 +2,7 @@ import sys from tests.testmodels import Tournament, UniqueName +from tortoise import Tortoise, connections from tortoise.contrib import test from tortoise.contrib.test.condition import NotEQ from tortoise.transactions import in_transaction @@ -89,3 +90,26 @@ async def test_nonconcurrent_get_or_create(self): self.assertEqual(len(una_created), 1) for una in unas: self.assertEqual(una[0], unas[0][0]) + + +class TestConcurrentDBConnectionInitialization(test.IsolatedTestCase): + """Tortoise.init is lazy and does not initialize the database connection until the first query. + These tests ensure that concurrent queries do not cause initialization issues.""" + + async def _setUpDB(self) -> None: + """Override to avoid database connection initialization when generating the schema.""" + await super()._setUpDB() + config = test.getDBConfig(app_label="models", modules=test._MODULES) + await Tortoise.init(config, _create_db=True) + + async def test_concurrent_queries(self): + await asyncio.gather( + *[connections.get("models").execute_query("SELECT 1") for _ in range(100)] + ) + + async def test_concurrent_transactions(self) -> None: + async def transaction() -> None: + async with in_transaction(): + await connections.get("models").execute_query("SELECT 1") + + await asyncio.gather(*[transaction() for _ in range(100)]) diff --git a/tortoise/__init__.py b/tortoise/__init__.py index 2aace9a16..575304fb2 100644 --- a/tortoise/__init__.py +++ b/tortoise/__init__.py @@ -399,7 +399,9 @@ async def init( table_name_generator: Callable[[Type["Model"]], str] | None = None, ) -> None: """ - Sets up Tortoise-ORM. + Sets up Tortoise-ORM: loads apps and models, configures database connections but does not + connect to the database yet. The actual connection or connection pool is established + lazily on first query execution. You can configure using only one of ``config``, ``config_file`` and ``(db_url, modules)``. diff --git a/tortoise/backends/asyncpg/client.py b/tortoise/backends/asyncpg/client.py index a92bedaa1..88a9d8522 100644 --- a/tortoise/backends/asyncpg/client.py +++ b/tortoise/backends/asyncpg/client.py @@ -100,10 +100,10 @@ async def db_delete(self) -> None: await self.close() def acquire_connection(self) -> Union["PoolConnectionWrapper", "ConnectionWrapper"]: - return PoolConnectionWrapper(self) + return PoolConnectionWrapper(self, self._pool_init_lock) def _in_transaction(self) -> "TransactionContext": - return TransactionContextPooled(TransactionWrapper(self)) + return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock) @translate_exceptions async def execute_insert(self, query: str, values: list) -> Optional[asyncpg.Record]: diff --git a/tortoise/backends/base/client.py b/tortoise/backends/base/client.py index 79583642f..ea08e88d9 100644 --- a/tortoise/backends/base/client.py +++ b/tortoise/backends/base/client.py @@ -222,10 +222,10 @@ class ConnectionWrapper(Generic[T_conn]): """Wraps the connections with a lock to facilitate safe concurrent access when using asyncio.gather, TaskGroup, or similar.""" - __slots__ = ("connection", "lock", "client") + __slots__ = ("connection", "_lock", "client") def __init__(self, lock: asyncio.Lock, client: Any) -> None: - self.lock: asyncio.Lock = lock + self._lock: asyncio.Lock = lock self.client = client self.connection: T_conn = client._connection @@ -235,12 +235,12 @@ async def ensure_connection(self) -> None: self.connection = self.client._connection async def __aenter__(self) -> T_conn: - await self.lock.acquire() + await self._lock.acquire() await self.ensure_connection() return self.connection async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: - self.lock.release() + self._lock.release() class TransactionContext(Generic[T_conn]): @@ -259,15 +259,19 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ... class TransactionContextPooled(TransactionContext): "A version of TransactionContext that uses a pool to acquire connections." - __slots__ = ("conn_wrapper", "connection", "connection_name", "token") + __slots__ = ("connection", "connection_name", "token", "_pool_init_lock") - def __init__(self, connection: Any) -> None: + def __init__(self, connection: Any, pool_init_lock: asyncio.Lock) -> None: self.connection = connection self.connection_name = connection.connection_name + self._pool_init_lock = pool_init_lock async def ensure_connection(self) -> None: if not self.connection._parent._pool: - await self.connection._parent.create_connection(with_db=True) + # a safeguard against multiple concurrent tasks trying to initialize the pool + async with self._pool_init_lock: + if not self.connection._parent._pool: + await self.connection._parent.create_connection(with_db=True) async def __aenter__(self) -> T_conn: await self.ensure_connection() @@ -315,25 +319,27 @@ async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: class PoolConnectionWrapper(Generic[T_conn]): """Class to manage acquiring from and releasing connections to a pool.""" - def __init__(self, client: Any) -> None: - self.pool = client._pool + def __init__(self, client: Any, pool_init_lock: asyncio.Lock) -> None: self.client = client self.connection: Optional[T_conn] = None + self._pool_init_lock = pool_init_lock async def ensure_connection(self) -> None: - if not self.pool: - await self.client.create_connection(with_db=True) - self.pool = self.client._pool + if not self.client._pool: + # a safeguard against multiple concurrent tasks trying to initialize the pool + async with self._pool_init_lock: + if not self.client._pool: + await self.client.create_connection(with_db=True) async def __aenter__(self) -> T_conn: await self.ensure_connection() # get first available connection. If none available, wait until one is released - self.connection = await self.pool.acquire() + self.connection = await self.client._pool.acquire() return cast(T_conn, self.connection) async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: # release the connection back to the pool - await self.pool.release(self.connection) + await self.client._pool.release(self.connection) class BaseTransactionWrapper: diff --git a/tortoise/backends/base_postgres/client.py b/tortoise/backends/base_postgres/client.py index 9fdb39355..73131e4c4 100644 --- a/tortoise/backends/base_postgres/client.py +++ b/tortoise/backends/base_postgres/client.py @@ -1,4 +1,5 @@ import abc +import asyncio from asyncio.events import AbstractEventLoop from functools import wraps from typing import ( @@ -90,6 +91,7 @@ def __init__( self._template: dict = {} self._pool = None self._connection = None + self._pool_init_lock = asyncio.Lock() @abc.abstractmethod async def create_connection(self, with_db: bool) -> None: @@ -128,7 +130,7 @@ async def db_delete(self) -> None: await self.close() def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]: - return PoolConnectionWrapper(self._pool) + return PoolConnectionWrapper(self._pool, self._pool_init_lock) @abc.abstractmethod def _in_transaction(self) -> "TransactionContext": diff --git a/tortoise/backends/mssql/client.py b/tortoise/backends/mssql/client.py index 424f5e759..db31e36f1 100644 --- a/tortoise/backends/mssql/client.py +++ b/tortoise/backends/mssql/client.py @@ -41,7 +41,7 @@ def __init__( self.dsn = f"DRIVER={driver};SERVER={host},{port};UID={user};PWD={password};" def _in_transaction(self) -> "TransactionContext": - return TransactionContextPooled(TransactionWrapper(self)) + return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock) @translate_exceptions async def execute_insert(self, query: str, values: list) -> int: diff --git a/tortoise/backends/mysql/client.py b/tortoise/backends/mysql/client.py index d8b9e2118..6ad9f57a4 100644 --- a/tortoise/backends/mysql/client.py +++ b/tortoise/backends/mysql/client.py @@ -108,6 +108,7 @@ def __init__( self._template: dict = {} self._pool: Optional[mysql.Pool] = None self._connection = None + self._pool_init_lock = asyncio.Lock() async def create_connection(self, with_db: bool) -> None: if charset_by_name(self.charset) is None: @@ -172,10 +173,10 @@ async def db_delete(self) -> None: await self.close() def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]: - return PoolConnectionWrapper(self) + return PoolConnectionWrapper(self, self._pool_init_lock) def _in_transaction(self) -> "TransactionContext": - return TransactionContextPooled(TransactionWrapper(self)) + return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock) @translate_exceptions async def execute_insert(self, query: str, values: list) -> int: diff --git a/tortoise/backends/odbc/client.py b/tortoise/backends/odbc/client.py index 39bf16099..aac6f6306 100644 --- a/tortoise/backends/odbc/client.py +++ b/tortoise/backends/odbc/client.py @@ -70,6 +70,7 @@ def __init__( self._template: dict = {} self._pool: Optional[asyncodbc.Pool] = None self._connection = None + self._pool_init_lock = asyncio.Lock() async def create_connection(self, with_db: bool) -> None: self._template = { @@ -114,7 +115,7 @@ async def close(self) -> None: self._pool = None def acquire_connection(self) -> ConnWrapperType: - return PoolConnectionWrapper(self) + return PoolConnectionWrapper(self, self._pool_init_lock) @translate_exceptions async def execute_many(self, query: str, values: list) -> None: diff --git a/tortoise/backends/oracle/client.py b/tortoise/backends/oracle/client.py index 153865a0f..b561a9f5d 100644 --- a/tortoise/backends/oracle/client.py +++ b/tortoise/backends/oracle/client.py @@ -57,10 +57,10 @@ def __init__( self.dsn = f"DRIVER={driver};DBQ={dbq};UID={user};PWD={password};" def _in_transaction(self) -> "TransactionContext": - return TransactionContextPooled(TransactionWrapper(self)) + return TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock) def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]: - return OraclePoolConnectionWrapper(self) + return OraclePoolConnectionWrapper(self, self._pool_init_lock) async def db_create(self) -> None: await self.create_connection(with_db=False) diff --git a/tortoise/backends/psycopg/client.py b/tortoise/backends/psycopg/client.py index 03e3e0d6e..b73d57f49 100644 --- a/tortoise/backends/psycopg/client.py +++ b/tortoise/backends/psycopg/client.py @@ -193,10 +193,10 @@ async def _translate_exceptions(self, func, *args, **kwargs) -> Exception: def acquire_connection( self, ) -> typing.Union[base_client.ConnectionWrapper, PoolConnectionWrapper]: - return PoolConnectionWrapper(self) + return PoolConnectionWrapper(self, self._pool_init_lock) def _in_transaction(self) -> base_client.TransactionContext: - return base_client.TransactionContextPooled(TransactionWrapper(self)) + return base_client.TransactionContextPooled(TransactionWrapper(self), self._pool_init_lock) class TransactionWrapper(PsycopgClient, base_client.BaseTransactionWrapper): diff --git a/tortoise/backends/sqlite/client.py b/tortoise/backends/sqlite/client.py index 6045e9f39..7372819fe 100644 --- a/tortoise/backends/sqlite/client.py +++ b/tortoise/backends/sqlite/client.py @@ -190,8 +190,8 @@ async def ensure_connection(self) -> None: self.connection._connection = self.connection._parent._connection async def __aenter__(self) -> T_conn: - await self.ensure_connection() await self._trxlock.acquire() + await self.ensure_connection() self.token = connections.set(self.connection_name, self.connection) await self.connection.begin() return self.connection