Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added asyncio support for sqlalchemy #72

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 65 additions & 26 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import datetime
from decimal import Decimal
from typing import NamedTuple

from sqlalchemy.schema import CreateTable, DropTable
import pytest
import sqlalchemy as sa
import ydb
from sqlalchemy import Column, Integer, String, Table, Unicode
from sqlalchemy import Column, Integer, String, Table, Unicode, insert, select
from sqlalchemy.ext.asyncio import async_sessionmaker
from sqlalchemy.testing import async_test
from sqlalchemy.testing.fixtures import TablesTest, TestBase, config
from ydb._grpc.v4.protos import ydb_common_pb2

Expand Down Expand Up @@ -380,11 +382,11 @@ def test_auto_partitioning_partition_size_mb(self, connection, auto_partitioning
],
)
def test_auto_partitioning_min_partitions_count(
self,
connection,
auto_partitioning_min_partitions_count,
res,
metadata,
self,
connection,
auto_partitioning_min_partitions_count,
res,
metadata,
):
desc = self._create_table_and_get_desc(
connection,
Expand All @@ -401,11 +403,11 @@ def test_auto_partitioning_min_partitions_count(
],
)
def test_auto_partitioning_max_partitions_count(
self,
connection,
auto_partitioning_max_partitions_count,
res,
metadata,
self,
connection,
auto_partitioning_max_partitions_count,
res,
metadata,
):
desc = self._create_table_and_get_desc(
connection,
Expand All @@ -422,11 +424,11 @@ def test_auto_partitioning_max_partitions_count(
],
)
def test_uniform_partitions(
self,
connection,
uniform_partitions,
res,
metadata,
self,
connection,
uniform_partitions,
res,
metadata,
):
desc = self._create_table_and_get_desc(
connection,
Expand All @@ -444,11 +446,11 @@ def test_uniform_partitions(
],
)
def test_partition_at_keys(
self,
connection,
partition_at_keys,
res,
metadata,
self,
connection,
partition_at_keys,
res,
metadata,
):
desc = self._create_table_and_get_desc(
connection,
Expand Down Expand Up @@ -535,10 +537,10 @@ def test_interactive_transaction(self, connection_no_trans, connection, isolatio
@pytest.mark.parametrize(
"isolation_level",
(
IsolationLevel.ONLINE_READONLY,
IsolationLevel.ONLINE_READONLY_INCONSISTENT,
IsolationLevel.STALE_READONLY,
IsolationLevel.AUTOCOMMIT,
IsolationLevel.ONLINE_READONLY,
IsolationLevel.ONLINE_READONLY_INCONSISTENT,
IsolationLevel.STALE_READONLY,
IsolationLevel.AUTOCOMMIT,
),
)
def test_not_interactive_transaction(self, connection_no_trans, connection, isolation_level):
Expand Down Expand Up @@ -673,6 +675,13 @@ def ydb_pool(self, ydb_driver):
finally:
loop.run_until_complete(session_pool.stop())

@pytest.mark.asyncio
async def test_crud_commands_on_session(self, async_testing_engine):
engine = async_testing_engine()
maker = async_sessionmaker(engine)
async with maker() as session:
await session.execute(sa.text("SELECT 1 as value"))


class TestCredentials(TestBase):
__backend__ = True
Expand Down Expand Up @@ -725,6 +734,36 @@ def test_ydb_credentials_bad(self, query_client_settings, driver_config_for_cred
assert "Invalid password" in str(excinfo.value)


class TestAsyncCRUD(TestBase):
__only_on__ = "yql+ydb_async"

@async_test
async def test_crud(self, async_testing_engine, metadata):
engine = async_testing_engine()
maker = async_sessionmaker(engine)
async with maker() as session:
res = await session.scalar(sa.text("SELECT 1 as value"))

assert res == 1
table = Table(
'test',
metadata,
Column("id", Integer, primary_key=True),
Column("name", String),
)

async with maker() as session:
await session.execute(CreateTable(table))

# t1 = TestModel(id=1, name="test")
stmt = insert(table).values(id=1, name="test")
await session.execute(stmt)
stmt = select(table).where(table.c.id == 1)
t2 = await session.scalar(stmt)
assert t2 == 1
await session.execute(DropTable(table))


class TestUpsert(TablesTest):
__backend__ = True

Expand Down
104 changes: 69 additions & 35 deletions ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@

import collections
import collections.abc
from typing import Any, Mapping, Optional, Sequence, Tuple, Union
from typing import Any, Mapping, Optional, Sequence, Tuple, Union, Type

import sqlalchemy as sa
import ydb
from sqlalchemy import util
from sqlalchemy import util, AsyncAdaptedQueuePool, URL, Pool
from sqlalchemy.engine import characteristics, reflection
from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect
from sqlalchemy.exc import NoSuchTableError
Expand All @@ -22,10 +22,10 @@
from ydb_sqlalchemy.sqlalchemy.dml import Upsert

from ydb_sqlalchemy.sqlalchemy.compiler import YqlCompiler, YqlDDLCompiler, YqlIdentifierPreparer, YqlTypeCompiler

from ydb_sqlalchemy.sqlalchemy.dbapi_adapter import AdaptedAsyncCursor
from ydb_dbapi.utils import CursorStatus
from . import types


OLD_SA = sa.__version__ < "2."


Expand Down Expand Up @@ -86,12 +86,12 @@ def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: ydb_dbap
dialect.reset_ydb_request_settings(dbapi_connection)

def set_characteristic(
self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection, value: ydb.BaseRequestSettings
self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection, value: ydb.BaseRequestSettings
) -> None:
dialect.set_ydb_request_settings(dbapi_connection, value)

def get_characteristic(
self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection
self, dialect: "YqlDialect", dbapi_connection: ydb_dbapi.Connection
) -> ydb.BaseRequestSettings:
return dialect.get_ydb_request_settings(dbapi_connection)

Expand Down Expand Up @@ -179,11 +179,11 @@ def dbapi(cls):
return cls.import_dbapi()

def __init__(
self,
json_serializer=None,
json_deserializer=None,
_add_declare_for_yql_stmt_vars=False,
**kwargs,
self,
json_serializer=None,
json_deserializer=None,
_add_declare_for_yql_stmt_vars=False,
**kwargs,
):
super().__init__(**kwargs)

Expand Down Expand Up @@ -295,9 +295,9 @@ def get_isolation_level(self, dbapi_connection: ydb_dbapi.Connection) -> str:
return dbapi_connection.get_isolation_level()

def set_ydb_request_settings(
self,
dbapi_connection: ydb_dbapi.Connection,
value: ydb.BaseRequestSettings,
self,
dbapi_connection: ydb_dbapi.Connection,
value: ydb.BaseRequestSettings,
) -> None:
dbapi_connection.set_ydb_request_settings(value)

Expand Down Expand Up @@ -332,10 +332,10 @@ def _handle_column_name(self, variable):
return "`" + variable + "`"

def _format_variables(
self,
statement: str,
parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]],
execute_many: bool,
self,
statement: str,
parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]],
execute_many: bool,
) -> Tuple[str, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
formatted_statement = statement
formatted_parameters = None
Expand Down Expand Up @@ -370,7 +370,7 @@ def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
return f"{declarations}\n{statement}"

def __merge_parameters_values_and_types(
self, values: Mapping[str, Any], types: Mapping[str, Any], execute_many: bool
self, values: Mapping[str, Any], types: Mapping[str, Any], execute_many: bool
) -> Sequence[Mapping[str, ydb.TypedValue]]:
if isinstance(values, collections.abc.Mapping):
values = [values]
Expand All @@ -387,11 +387,11 @@ def __merge_parameters_values_and_types(
return result_list if execute_many else result_list[0]

def _prepare_ydb_query(
self,
statement: str,
context: Optional[DefaultExecutionContext] = None,
parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None,
execute_many: bool = False,
self,
statement: str,
context: Optional[DefaultExecutionContext] = None,
parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None,
execute_many: bool = False,
) -> Tuple[Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]:
is_ddl = context.isddl if context is not None else False

Expand All @@ -417,21 +417,21 @@ def do_ping(self, dbapi_connection: ydb_dbapi.Connection) -> bool:
return True

def do_executemany(
self,
cursor: ydb_dbapi.Cursor,
statement: str,
parameters: Optional[Sequence[Mapping[str, Any]]],
context: Optional[DefaultExecutionContext] = None,
self,
cursor: ydb_dbapi.Cursor,
statement: str,
parameters: Optional[Sequence[Mapping[str, Any]]],
context: Optional[DefaultExecutionContext] = None,
) -> None:
operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=True)
cursor.executemany(operation, parameters)

def do_execute(
self,
cursor: ydb_dbapi.Cursor,
statement: str,
parameters: Optional[Mapping[str, Any]] = None,
context: Optional[DefaultExecutionContext] = None,
self,
cursor: ydb_dbapi.Cursor,
statement: str,
parameters: Optional[Mapping[str, Any]] = None,
context: Optional[DefaultExecutionContext] = None,
) -> None:
operation, parameters = self._prepare_ydb_query(statement, context, parameters, execute_many=False)
is_ddl = context.isddl if context is not None else False
Expand All @@ -441,10 +441,44 @@ def do_execute(
cursor.execute(operation, parameters)





class AsyncCursor(AdaptedAsyncCursor):
def fetchone(self):
return self._cursor._fetchone_from_buffer()

def fetchmany(self, size=None):
size = size or self.arraysize
return self._cursor._fetchmany_from_buffer(size)

def fetchall(self):
return self._cursor._fetchall_from_buffer()

def close(self):
self._cursor._state = CursorStatus.closed


class AsyncConnection(AdaptedAsyncConnection):
def cursor(self):
return AsyncCursor(self._connection.cursor())


class AsyncYqlDialect(YqlDialect):
driver = "ydb_async"
is_async = True
supports_statement_cache = True

def __init__(self, json_serializer=None,
json_deserializer=None,
_add_declare_for_yql_stmt_vars=True,
**kwargs):
super().__init__(json_serializer=json_serializer, json_deserializer=json_deserializer,
_add_declare_for_yql_stmt_vars=_add_declare_for_yql_stmt_vars,
**kwargs)

def connect(self, *cargs, **cparams):
return AdaptedAsyncConnection(util.await_only(self.dbapi.async_connect(*cargs, **cparams)))
return AsyncConnection(util.await_only(self.dbapi.async_connect(*cargs, **cparams)))

def get_dialect_pool_class(self, url: URL) -> Type[Pool]:
return AsyncAdaptedQueuePool
Loading