diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index 6f01ae0..08e4a9b 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -9,6 +9,28 @@ class BaseDBApiTestSuit: + def _test_isolation_level_read_only(self, connection: dbapi.Connection, isolation_level: str, read_only: bool): + connection.cursor().execute( + dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True) + ) + connection.set_isolation_level(isolation_level) + + cursor = connection.cursor() + + connection.begin() + + query = dbapi.YdbQuery("UPSERT INTO foo(id) VALUES (1)") + if read_only: + with pytest.raises(dbapi.DatabaseError): + cursor.execute(query) + else: + cursor.execute(query) + + connection.rollback() + + connection.cursor().execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + connection.cursor().close() + def _test_connection(self, connection: dbapi.Connection): connection.commit() connection.rollback() @@ -100,7 +122,7 @@ def _test_errors(self, connection: dbapi.Connection): class TestSyncConnection(BaseDBApiTestSuit): - @pytest.fixture(scope="class") + @pytest.fixture def sync_connection(self) -> dbapi.Connection: conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") try: @@ -108,6 +130,20 @@ def sync_connection(self) -> dbapi.Connection: finally: conn.close() + @pytest.mark.parametrize( + "isolation_level, read_only", + [ + (dbapi.IsolationLevel.SERIALIZABLE, False), + (dbapi.IsolationLevel.AUTOCOMMIT, False), + (dbapi.IsolationLevel.ONLINE_READONLY, True), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), + (dbapi.IsolationLevel.STALE_READONLY, True), + (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), + ], + ) + def test_isolation_level_read_only(self, isolation_level: str, read_only: bool, sync_connection: dbapi.Connection): + self._test_isolation_level_read_only(sync_connection, isolation_level, read_only) + def test_connection(self, sync_connection: dbapi.Connection): self._test_connection(sync_connection) @@ -118,9 +154,8 @@ def test_errors(self, sync_connection: dbapi.Connection): return self._test_errors(sync_connection) -@pytest.mark.asyncio(scope="class") class TestAsyncConnection(BaseDBApiTestSuit): - @pytest_asyncio.fixture(scope="class") + @pytest_asyncio.fixture async def async_connection(self) -> dbapi.AsyncConnection: def connect(): return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local") @@ -131,11 +166,31 @@ def connect(): finally: await util.greenlet_spawn(conn.close) + @pytest.mark.asyncio + @pytest.mark.parametrize( + "isolation_level, read_only", + [ + (dbapi.IsolationLevel.SERIALIZABLE, False), + (dbapi.IsolationLevel.AUTOCOMMIT, False), + (dbapi.IsolationLevel.ONLINE_READONLY, True), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), + (dbapi.IsolationLevel.STALE_READONLY, True), + (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), + ], + ) + async def test_isolation_level_read_only( + self, isolation_level: str, read_only: bool, async_connection: dbapi.AsyncConnection + ): + await util.greenlet_spawn(self._test_isolation_level_read_only, async_connection, isolation_level, read_only) + + @pytest.mark.asyncio async def test_connection(self, async_connection: dbapi.AsyncConnection): await util.greenlet_spawn(self._test_connection, async_connection) + @pytest.mark.asyncio async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection): await util.greenlet_spawn(self._test_cursor_raw_query, async_connection) + @pytest.mark.asyncio async def test_errors(self, async_connection: dbapi.AsyncConnection): await util.greenlet_spawn(self._test_errors, async_connection) diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index e198924..a90239f 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -57,7 +57,7 @@ def __init__( self.tx_context: Optional[ydb.TxContext] = None def cursor(self): - return self._cursor_class(self.session_pool, self.tx_context) + return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context) def describe(self, table_path: str) -> ydb.TableDescription: abs_table_path = posixpath.join(self.database, table_path) diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 4d2a038..9d74424 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -76,9 +76,11 @@ class Cursor: def __init__( self, session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], + tx_mode: ydb.AbstractTransactionModeBuilder, tx_context: Optional[ydb.BaseTxContext] = None, ): self.session_pool = session_pool + self.tx_mode = tx_mode self.tx_context = tx_context self.description = None self.arraysize = 1 @@ -142,7 +144,7 @@ def _execute_dml( if self.tx_context: return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters) - return self._retry_operation_in_pool(self._execute_in_session, prepared_query, parameters) + return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters) @_handle_ydb_errors def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: @@ -176,9 +178,12 @@ def _execute_in_tx( @staticmethod def _execute_in_session( - session: ydb.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + session: ydb.Session, + tx_mode: ydb.AbstractTransactionModeBuilder, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], ) -> ydb.convert.ResultSets: - return session.transaction().execute(prepared_query, parameters, commit_tx=True) + return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): return callee(self.tx_context, *args, **kwargs) @@ -282,9 +287,12 @@ async def _execute_in_tx( @staticmethod async def _execute_in_session( - session: ydb.aio.table.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + session: ydb.aio.table.Session, + tx_mode: ydb.AbstractTransactionModeBuilder, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], ) -> ydb.convert.ResultSets: - return await session.transaction().execute(prepared_query, parameters, commit_tx=True) + return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs): return self._await(callee(self.tx_context, *args, **kwargs))