Skip to content

Commit

Permalink
Provide tx_mode to one-time transactions
Browse files Browse the repository at this point in the history
  • Loading branch information
LuckySting committed Feb 2, 2024
1 parent 5c4d6ca commit 4170e5f
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 9 deletions.
61 changes: 58 additions & 3 deletions test_dbapi/test_dbapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,28 @@


class BaseDBApiTestSuit:
def _test_isolation_level_read_only(self, connection: dbapi.Connection, isolation_level: str, read_only: bool):
connection.cursor().execute(
dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)
)
connection.set_isolation_level(isolation_level)

cursor = connection.cursor()

connection.begin()

query = dbapi.YdbQuery("UPSERT INTO foo(id) VALUES (1)")
if read_only:
with pytest.raises(dbapi.DatabaseError):
cursor.execute(query)
else:
cursor.execute(query)

connection.rollback()

connection.cursor().execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True))
connection.cursor().close()

def _test_connection(self, connection: dbapi.Connection):
connection.commit()
connection.rollback()
Expand Down Expand Up @@ -100,14 +122,28 @@ def _test_errors(self, connection: dbapi.Connection):


class TestSyncConnection(BaseDBApiTestSuit):
@pytest.fixture(scope="class")
@pytest.fixture
def sync_connection(self) -> dbapi.Connection:
conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local")
try:
yield conn
finally:
conn.close()

@pytest.mark.parametrize(
"isolation_level, read_only",
[
(dbapi.IsolationLevel.SERIALIZABLE, False),
(dbapi.IsolationLevel.AUTOCOMMIT, False),
(dbapi.IsolationLevel.ONLINE_READONLY, True),
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True),
(dbapi.IsolationLevel.STALE_READONLY, True),
(dbapi.IsolationLevel.SNAPSHOT_READONLY, True),
],
)
def test_isolation_level_read_only(self, isolation_level: str, read_only: bool, sync_connection: dbapi.Connection):
self._test_isolation_level_read_only(sync_connection, isolation_level, read_only)

def test_connection(self, sync_connection: dbapi.Connection):
self._test_connection(sync_connection)

Expand All @@ -118,9 +154,8 @@ def test_errors(self, sync_connection: dbapi.Connection):
return self._test_errors(sync_connection)


@pytest.mark.asyncio(scope="class")
class TestAsyncConnection(BaseDBApiTestSuit):
@pytest_asyncio.fixture(scope="class")
@pytest_asyncio.fixture
async def async_connection(self) -> dbapi.AsyncConnection:
def connect():
return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local")
Expand All @@ -131,11 +166,31 @@ def connect():
finally:
await util.greenlet_spawn(conn.close)

@pytest.mark.asyncio
@pytest.mark.parametrize(
"isolation_level, read_only",
[
(dbapi.IsolationLevel.SERIALIZABLE, False),
(dbapi.IsolationLevel.AUTOCOMMIT, False),
(dbapi.IsolationLevel.ONLINE_READONLY, True),
(dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True),
(dbapi.IsolationLevel.STALE_READONLY, True),
(dbapi.IsolationLevel.SNAPSHOT_READONLY, True),
],
)
async def test_isolation_level_read_only(
self, isolation_level: str, read_only: bool, async_connection: dbapi.AsyncConnection
):
await util.greenlet_spawn(self._test_isolation_level_read_only, async_connection, isolation_level, read_only)

@pytest.mark.asyncio
async def test_connection(self, async_connection: dbapi.AsyncConnection):
await util.greenlet_spawn(self._test_connection, async_connection)

@pytest.mark.asyncio
async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection):
await util.greenlet_spawn(self._test_cursor_raw_query, async_connection)

@pytest.mark.asyncio
async def test_errors(self, async_connection: dbapi.AsyncConnection):
await util.greenlet_spawn(self._test_errors, async_connection)
2 changes: 1 addition & 1 deletion ydb_sqlalchemy/dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def __init__(
self.tx_context: Optional[ydb.TxContext] = None

def cursor(self):
return self._cursor_class(self.session_pool, self.tx_context)
return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context)

def describe(self, table_path: str) -> ydb.TableDescription:
abs_table_path = posixpath.join(self.database, table_path)
Expand Down
18 changes: 13 additions & 5 deletions ydb_sqlalchemy/dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,11 @@ class Cursor:
def __init__(
self,
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
tx_mode: ydb.AbstractTransactionModeBuilder,
tx_context: Optional[ydb.BaseTxContext] = None,
):
self.session_pool = session_pool
self.tx_mode = tx_mode
self.tx_context = tx_context
self.description = None
self.arraysize = 1
Expand Down Expand Up @@ -142,7 +144,7 @@ def _execute_dml(
if self.tx_context:
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters)

return self._retry_operation_in_pool(self._execute_in_session, prepared_query, parameters)
return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters)

@_handle_ydb_errors
def _execute_ddl(self, query: str) -> ydb.convert.ResultSets:
Expand Down Expand Up @@ -176,9 +178,12 @@ def _execute_in_tx(

@staticmethod
def _execute_in_session(
session: ydb.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
session: ydb.Session,
tx_mode: ydb.AbstractTransactionModeBuilder,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
) -> ydb.convert.ResultSets:
return session.transaction().execute(prepared_query, parameters, commit_tx=True)
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
return callee(self.tx_context, *args, **kwargs)
Expand Down Expand Up @@ -282,9 +287,12 @@ async def _execute_in_tx(

@staticmethod
async def _execute_in_session(
session: ydb.aio.table.Session, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
session: ydb.aio.table.Session,
tx_mode: ydb.AbstractTransactionModeBuilder,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
) -> ydb.convert.ResultSets:
return await session.transaction().execute(prepared_query, parameters, commit_tx=True)
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)

def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs):
return self._await(callee(self.tx_context, *args, **kwargs))
Expand Down

0 comments on commit 4170e5f

Please sign in to comment.