diff --git a/.gitignore b/.gitignore index bd6ad26..adcbed1 100644 --- a/.gitignore +++ b/.gitignore @@ -130,3 +130,6 @@ dmypy.json # PyCharm .idea/ + +# VSCode +.vscode diff --git a/README.md b/README.md index a2fcef6..8722ff8 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ $ tox -e style Reformat code: ```bash +$ tox -e isort $ tox -e black-format ``` diff --git a/examples/example.py b/examples/example.py index 5d6a786..96f75bc 100644 --- a/examples/example.py +++ b/examples/example.py @@ -1,11 +1,10 @@ import datetime import logging -import sqlalchemy as sa -from sqlalchemy import orm, exc, sql -from sqlalchemy import Table, Column, Integer, String, Float, TIMESTAMP +import sqlalchemy as sa from fill_tables import fill_all_tables, to_days -from models import Base, Series, Episodes +from models import Base, Episodes, Series +from sqlalchemy import TIMESTAMP, Column, Float, Integer, String, Table, exc, orm, sql def describe_table(engine, name): diff --git a/examples/fill_tables.py b/examples/fill_tables.py index 5a9eb95..b047e65 100644 --- a/examples/fill_tables.py +++ b/examples/fill_tables.py @@ -1,7 +1,6 @@ import iso8601 - import sqlalchemy as sa -from models import Base, Series, Seasons, Episodes +from models import Base, Episodes, Seasons, Series def to_days(date): diff --git a/examples/models.py b/examples/models.py index a02349a..09f1882 100644 --- a/examples/models.py +++ b/examples/models.py @@ -1,7 +1,6 @@ import sqlalchemy.orm as orm from sqlalchemy import Column, Integer, Unicode - Base = orm.declarative_base() diff --git a/pyproject.toml b/pyproject.toml index 55ec8d7..85c3b07 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,2 +1,5 @@ [tool.black] line-length = 120 + +[tool.isort] +profile = "black" diff --git a/setup.cfg b/setup.cfg index d926c4e..650a160 100644 --- a/setup.cfg +++ b/setup.cfg @@ -7,3 +7,5 @@ profile_file=test/profiles.txt [db] default=yql+ydb://localhost:2136/local +ydb=yql+ydb://localhost:2136/local +ydb_async=yql+ydb_async://localhost:2136/local diff --git a/setup.py b/setup.py index 1cc3fb0..61878f5 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,8 @@ entry_points={ "sqlalchemy.dialects": [ "yql.ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect", + "yql.ydb_async=ydb_sqlalchemy.sqlalchemy:AsyncYqlDialect", + "ydb_async=ydb_sqlalchemy.sqlalchemy:AsyncYqlDialect", "ydb=ydb_sqlalchemy.sqlalchemy:YqlDialect", "yql=ydb_sqlalchemy.sqlalchemy:YqlDialect", ] diff --git a/test-requirements.txt b/test-requirements.txt index d345613..21e0953 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -9,3 +9,5 @@ dockerpty==0.4.1 flake8==3.9.2 black==23.3.0 pytest-cov +pytest-asyncio +isort==5.13.2 diff --git a/test/conftest.py b/test/conftest.py index 0f8b014..5c0ed41 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -2,6 +2,8 @@ from sqlalchemy.dialects import registry registry.register("yql.ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") +registry.register("yql.ydb_async", "ydb_sqlalchemy.sqlalchemy", "AsyncYqlDialect") +registry.register("ydb_async", "ydb_sqlalchemy.sqlalchemy", "AsyncYqlDialect") registry.register("ydb", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") registry.register("yql", "ydb_sqlalchemy.sqlalchemy", "YqlDialect") pytest.register_assert_rewrite("sqlalchemy.testing.assertions") diff --git a/test/test_core.py b/test/test_core.py index 1ff9f15..c2a1dc9 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -1,19 +1,18 @@ +import asyncio from datetime import date, datetime from decimal import Decimal from typing import NamedTuple import pytest - import sqlalchemy as sa -from sqlalchemy import Table, Column, Integer, Unicode, String -from sqlalchemy.testing.fixtures import TestBase, TablesTest, config - import ydb +from sqlalchemy import Column, Integer, String, Table, Unicode +from sqlalchemy.testing.fixtures import TablesTest, TestBase, config from ydb._grpc.v4.protos import ydb_common_pb2 -from ydb_sqlalchemy import dbapi, IsolationLevel -from ydb_sqlalchemy.sqlalchemy import types +from ydb_sqlalchemy import IsolationLevel, dbapi from ydb_sqlalchemy import sqlalchemy as ydb_sa +from ydb_sqlalchemy.sqlalchemy import types def clear_sql(stm): @@ -21,6 +20,8 @@ def clear_sql(stm): class TestText(TestBase): + __backend__ = True + def test_sa_text(self, connection): rs = connection.execute(sa.text("SELECT 1 AS value")) assert rs.fetchone() == (1,) @@ -38,6 +39,8 @@ def test_sa_text(self, connection): class TestCrud(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -82,6 +85,8 @@ def test_sa_crud(self, connection): class TestSimpleSelect(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -174,6 +179,8 @@ def test_sa_select_simple(self, connection): class TestTypes(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -211,6 +218,7 @@ def test_select_types(self, connection): class TestWithClause(TablesTest): + __backend__ = True run_create_tables = "each" @staticmethod @@ -223,10 +231,7 @@ def _create_table_and_get_desc(connection, metadata, **kwargs): ) table.create(connection) - session: ydb.Session = connection.connection.driver_connection.session_pool.acquire() - table_description = session.describe_table("/local/" + table.name) - connection.connection.driver_connection.session_pool.release(session) - return table_description + return connection.connection.driver_connection.describe(table.name) @pytest.mark.parametrize( "auto_partitioning_by_size,res", @@ -374,6 +379,8 @@ def test_several_keys(self, connection, metadata): class TestTransaction(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata: sa.MetaData): Table( @@ -462,6 +469,8 @@ def test_not_interactive_transaction( class TestTransactionIsolationLevel(TestBase): + __backend__ = True + class IsolationSettings(NamedTuple): ydb_mode: ydb.AbstractTransactionModeBuilder interactive: bool @@ -493,7 +502,10 @@ def test_connection_set(self, connection_no_trans: sa.Connection): class TestEngine(TestBase): - @pytest.fixture(scope="module") + __backend__ = True + __only_on__ = "yql+ydb" + + @pytest.fixture(scope="class") def ydb_driver(self): url = config.db_url driver = ydb.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) @@ -505,13 +517,14 @@ def ydb_driver(self): driver.stop() - @pytest.fixture(scope="module") + @pytest.fixture(scope="class") def ydb_pool(self, ydb_driver): session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1) - yield session_pool - - session_pool.stop() + try: + yield session_pool + finally: + session_pool.stop() def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool}) @@ -544,7 +557,34 @@ def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): assert not ydb_driver._stopped +class TestAsyncEngine(TestEngine): + __only_on__ = "yql+ydb_async" + + @pytest.fixture(scope="class") + def ydb_driver(self): + loop = asyncio.get_event_loop() + url = config.db_url + driver = ydb.aio.Driver(endpoint=f"grpc://{url.host}:{url.port}", database=url.database) + try: + loop.run_until_complete(driver.wait(timeout=5, fail_fast=True)) + yield driver + finally: + loop.run_until_complete(driver.stop()) + + @pytest.fixture(scope="class") + def ydb_pool(self, ydb_driver): + session_pool = ydb.aio.SessionPool(ydb_driver, size=5) + + try: + yield session_pool + finally: + loop = asyncio.get_event_loop() + loop.run_until_complete(session_pool.stop()) + + class TestUpsert(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( @@ -644,6 +684,8 @@ def test_upsert_from_select(self, connection, metadata): class TestUpsertDoesNotReplaceInsert(TablesTest): + __backend__ = True + @classmethod def define_tables(cls, metadata): Table( diff --git a/test/test_inspect.py b/test/test_inspect.py index 1fe61b8..0d4c9a7 100644 --- a/test/test_inspect.py +++ b/test/test_inspect.py @@ -1,6 +1,5 @@ import sqlalchemy as sa - -from sqlalchemy import Table, Column, Integer, Unicode, Numeric +from sqlalchemy import Column, Integer, Numeric, Table, Unicode from sqlalchemy.testing.fixtures import TablesTest diff --git a/test/test_suite.py b/test/test_suite.py index dc61109..d853aaf 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -1,55 +1,86 @@ import pytest import sqlalchemy as sa import sqlalchemy.testing.suite.test_types - +from sqlalchemy.testing import is_false, is_true from sqlalchemy.testing.suite import * # noqa: F401, F403 - -from sqlalchemy.testing import is_true, is_false -from sqlalchemy.testing.suite import eq_, testing, inspect, provide_metadata, config, requirements, fixtures -from sqlalchemy.testing.suite import func, column, literal_column, select, exists -from sqlalchemy.testing.suite import MetaData, Column, Table, Integer, String - -from sqlalchemy.testing.suite.test_select import ( - ExistsTest as _ExistsTest, - LikeFunctionsTest as _LikeFunctionsTest, - CompoundSelectTest as _CompoundSelectTest, +from sqlalchemy.testing.suite import ( + Column, + Integer, + MetaData, + String, + Table, + column, + config, + eq_, + exists, + fixtures, + func, + inspect, + literal_column, + provide_metadata, + requirements, + select, + testing, +) +from sqlalchemy.testing.suite.test_ddl import ( + LongNameBlowoutTest as _LongNameBlowoutTest, +) +from sqlalchemy.testing.suite.test_deprecations import ( + DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest, +) +from sqlalchemy.testing.suite.test_dialect import ( + DifficultParametersTest as _DifficultParametersTest, +) +from sqlalchemy.testing.suite.test_dialect import EscapingTest as _EscapingTest +from sqlalchemy.testing.suite.test_insert import ( + InsertBehaviorTest as _InsertBehaviorTest, ) from sqlalchemy.testing.suite.test_reflection import ( - HasTableTest as _HasTableTest, - HasIndexTest as _HasIndexTest, ComponentReflectionTest as _ComponentReflectionTest, - CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite.test_reflection import ( ComponentReflectionTestExtra as _ComponentReflectionTestExtra, +) +from sqlalchemy.testing.suite.test_reflection import ( + CompositeKeyReflectionTest as _CompositeKeyReflectionTest, +) +from sqlalchemy.testing.suite.test_reflection import HasIndexTest as _HasIndexTest +from sqlalchemy.testing.suite.test_reflection import HasTableTest as _HasTableTest +from sqlalchemy.testing.suite.test_reflection import ( QuotedNameArgumentTest as _QuotedNameArgumentTest, ) +from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest +from sqlalchemy.testing.suite.test_select import ( + CompoundSelectTest as _CompoundSelectTest, +) +from sqlalchemy.testing.suite.test_select import ExistsTest as _ExistsTest +from sqlalchemy.testing.suite.test_select import ( + FetchLimitOffsetTest as _FetchLimitOffsetTest, +) +from sqlalchemy.testing.suite.test_select import JoinTest as _JoinTest +from sqlalchemy.testing.suite.test_select import LikeFunctionsTest as _LikeFunctionsTest +from sqlalchemy.testing.suite.test_select import OrderByLabelTest as _OrderByLabelTest +from sqlalchemy.testing.suite.test_types import BinaryTest as _BinaryTest +from sqlalchemy.testing.suite.test_types import DateTest as _DateTest from sqlalchemy.testing.suite.test_types import ( - IntegerTest as _IntegerTest, - NumericTest as _NumericTest, - BinaryTest as _BinaryTest, - TrueDivTest as _TrueDivTest, - TimeTest as _TimeTest, - StringTest as _StringTest, - NativeUUIDTest as _NativeUUIDTest, - TimeMicrosecondsTest as _TimeMicrosecondsTest, DateTimeCoercedToDateTimeTest as _DateTimeCoercedToDateTimeTest, - DateTest as _DateTest, +) +from sqlalchemy.testing.suite.test_types import ( DateTimeMicrosecondsTest as _DateTimeMicrosecondsTest, - DateTimeTest as _DateTimeTest, - TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) -from sqlalchemy.testing.suite.test_dialect import ( - EscapingTest as _EscapingTest, - DifficultParametersTest as _DifficultParametersTest, +from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest +from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest +from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest +from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest +from sqlalchemy.testing.suite.test_types import StringTest as _StringTest +from sqlalchemy.testing.suite.test_types import ( + TimeMicrosecondsTest as _TimeMicrosecondsTest, ) -from sqlalchemy.testing.suite.test_select import ( - JoinTest as _JoinTest, - OrderByLabelTest as _OrderByLabelTest, - FetchLimitOffsetTest as _FetchLimitOffsetTest, +from sqlalchemy.testing.suite.test_types import ( + TimestampMicrosecondsTest as _TimestampMicrosecondsTest, ) -from sqlalchemy.testing.suite.test_insert import InsertBehaviorTest as _InsertBehaviorTest -from sqlalchemy.testing.suite.test_ddl import LongNameBlowoutTest as _LongNameBlowoutTest -from sqlalchemy.testing.suite.test_results import RowFetchTest as _RowFetchTest -from sqlalchemy.testing.suite.test_deprecations import DeprecatedCompoundSelectTest as _DeprecatedCompoundSelectTest +from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest +from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types diff --git a/test_dbapi/conftest.py b/test_dbapi/conftest.py deleted file mode 100644 index 7a9f5a3..0000000 --- a/test_dbapi/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest - -import ydb_sqlalchemy.dbapi as dbapi - - -@pytest.fixture(scope="module") -def connection(): - conn = dbapi.connect(host="localhost", port="2136", database="/local") - yield conn - conn.close() diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py index ad354d5..08e4a9b 100644 --- a/test_dbapi/test_dbapi.py +++ b/test_dbapi/test_dbapi.py @@ -1,98 +1,196 @@ -import pytest +from contextlib import suppress +import pytest +import pytest_asyncio +import sqlalchemy.util as util import ydb + import ydb_sqlalchemy.dbapi as dbapi -from contextlib import suppress +class BaseDBApiTestSuit: + def _test_isolation_level_read_only(self, connection: dbapi.Connection, isolation_level: str, read_only: bool): + connection.cursor().execute( + dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True) + ) + connection.set_isolation_level(isolation_level) -def test_connection(connection): - connection.commit() - connection.rollback() + cursor = connection.cursor() - cur = connection.cursor() - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + connection.begin() - assert not connection.check_exists("/local/foo") - with pytest.raises(dbapi.ProgrammingError): - connection.describe("/local/foo") + query = dbapi.YdbQuery("UPSERT INTO foo(id) VALUES (1)") + if read_only: + with pytest.raises(dbapi.DatabaseError): + cursor.execute(query) + else: + cursor.execute(query) - cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)) + connection.rollback() - assert connection.check_exists("/local/foo") + connection.cursor().execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + connection.cursor().close() - col = connection.describe("/local/foo").columns[0] - assert col.name == "id" - assert col.type == ydb.PrimitiveType.Int64 + def _test_connection(self, connection: dbapi.Connection): + connection.commit() + connection.rollback() - cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) - cur.close() + cur = connection.cursor() + with suppress(dbapi.DatabaseError): + cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + assert not connection.check_exists("/local/foo") + with pytest.raises(dbapi.ProgrammingError): + connection.describe("/local/foo") -def test_cursor_raw_query(connection): - cur = connection.cursor() - assert cur + cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)) - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + assert connection.check_exists("/local/foo") - cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True)) - - cur.execute( - dbapi.YdbQuery( - """ - DECLARE $data AS List>; - - INSERT INTO test SELECT id, text FROM AS_TABLE($data); - """, - parameters_types={ - "$data": ydb.ListType( - ydb.StructType() - .add_member("id", ydb.PrimitiveType.Int64) - .add_member("text", ydb.PrimitiveType.Utf8) - ) - }, - ), - { - "$data": [ - {"id": 17, "text": "seventeen"}, - {"id": 21, "text": "twenty one"}, - ] - }, - ) + col = connection.describe("/local/foo").columns[0] + assert col.name == "id" + assert col.type == ydb.PrimitiveType.Int64 - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) + cur.close() + + def _test_cursor_raw_query(self, connection: dbapi.Connection): + cur = connection.cursor() + assert cur + + with suppress(dbapi.DatabaseError): + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + + cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True)) + + cur.execute( + dbapi.YdbQuery( + """ + DECLARE $data AS List>; + + INSERT INTO test SELECT id, text FROM AS_TABLE($data); + """, + parameters_types={ + "$data": ydb.ListType( + ydb.StructType() + .add_member("id", ydb.PrimitiveType.Int64) + .add_member("text", ydb.PrimitiveType.Utf8) + ) + }, + ), + { + "$data": [ + {"id": 17, "text": "seventeen"}, + {"id": 21, "text": "twenty one"}, + ] + }, + ) - cur.close() + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + cur.close() -def test_errors(connection): - with pytest.raises(dbapi.InterfaceError): - dbapi.connect("localhost:2136", database="/local666") + def _test_errors(self, connection: dbapi.Connection): + with pytest.raises(dbapi.InterfaceError): + dbapi.YdbDBApi().connect("localhost:2136", database="/local666") - cur = connection.cursor() + cur = connection.cursor() - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + with suppress(dbapi.DatabaseError): + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616")) + with pytest.raises(dbapi.DataError): + cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616")) - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎")) + with pytest.raises(dbapi.DataError): + cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎")) - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)")) + with pytest.raises(dbapi.DataError): + cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)")) - with pytest.raises(dbapi.ProgrammingError): - cur.execute(dbapi.YdbQuery("SELECT * FROM test")) + with pytest.raises(dbapi.ProgrammingError): + cur.execute(dbapi.YdbQuery("SELECT * FROM test")) - cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True)) + cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True)) - cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - with pytest.raises(dbapi.IntegrityError): cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) + with pytest.raises(dbapi.IntegrityError): + cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - cur.close() + cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) + cur.close() + + +class TestSyncConnection(BaseDBApiTestSuit): + @pytest.fixture + def sync_connection(self) -> dbapi.Connection: + conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") + try: + yield conn + finally: + conn.close() + + @pytest.mark.parametrize( + "isolation_level, read_only", + [ + (dbapi.IsolationLevel.SERIALIZABLE, False), + (dbapi.IsolationLevel.AUTOCOMMIT, False), + (dbapi.IsolationLevel.ONLINE_READONLY, True), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), + (dbapi.IsolationLevel.STALE_READONLY, True), + (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), + ], + ) + def test_isolation_level_read_only(self, isolation_level: str, read_only: bool, sync_connection: dbapi.Connection): + self._test_isolation_level_read_only(sync_connection, isolation_level, read_only) + + def test_connection(self, sync_connection: dbapi.Connection): + self._test_connection(sync_connection) + + def test_cursor_raw_query(self, sync_connection: dbapi.Connection): + return self._test_cursor_raw_query(sync_connection) + + def test_errors(self, sync_connection: dbapi.Connection): + return self._test_errors(sync_connection) + + +class TestAsyncConnection(BaseDBApiTestSuit): + @pytest_asyncio.fixture + async def async_connection(self) -> dbapi.AsyncConnection: + def connect(): + return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local") + + conn = await util.greenlet_spawn(connect) + try: + yield conn + finally: + await util.greenlet_spawn(conn.close) + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "isolation_level, read_only", + [ + (dbapi.IsolationLevel.SERIALIZABLE, False), + (dbapi.IsolationLevel.AUTOCOMMIT, False), + (dbapi.IsolationLevel.ONLINE_READONLY, True), + (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), + (dbapi.IsolationLevel.STALE_READONLY, True), + (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), + ], + ) + async def test_isolation_level_read_only( + self, isolation_level: str, read_only: bool, async_connection: dbapi.AsyncConnection + ): + await util.greenlet_spawn(self._test_isolation_level_read_only, async_connection, isolation_level, read_only) + + @pytest.mark.asyncio + async def test_connection(self, async_connection: dbapi.AsyncConnection): + await util.greenlet_spawn(self._test_connection, async_connection) + + @pytest.mark.asyncio + async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection): + await util.greenlet_spawn(self._test_cursor_raw_query, async_connection) + + @pytest.mark.asyncio + async def test_errors(self, async_connection: dbapi.AsyncConnection): + await util.greenlet_spawn(self._test_errors, async_connection) diff --git a/tox.ini b/tox.ini index a776ff6..41ee405 100644 --- a/tox.ini +++ b/tox.ini @@ -25,7 +25,7 @@ ignore_errors = True commands = docker-compose up -d python {toxinidir}/wait_container_ready.py - pytest -v test + pytest -v test --dbdriver ydb --dbdriver ydb_async pytest -v test_dbapi pytest -v ydb_sqlalchemy docker-compose down @@ -60,6 +60,11 @@ skip_install = true commands = black ydb_sqlalchemy examples test test_dbapi +[testenv:isort] +skip_install = true +commands = + isort ydb_sqlalchemy examples test test_dbapi + [testenv:style] ignore_errors = True commands = diff --git a/ydb_sqlalchemy/__init__.py b/ydb_sqlalchemy/__init__.py index 2e5fbab..7e39278 100644 --- a/ydb_sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/__init__.py @@ -1 +1,2 @@ from .dbapi import IsolationLevel # noqa: F401 +from .sqlalchemy import Upsert, types, upsert # noqa: F401 diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py index f06e15f..f8fffe7 100644 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ b/ydb_sqlalchemy/dbapi/__init__.py @@ -1,37 +1,43 @@ -from .connection import Connection, IsolationLevel # noqa: F401 -from .cursor import Cursor, YdbQuery # noqa: F401 +from .connection import AsyncConnection, Connection, IsolationLevel # noqa: F401 +from .cursor import AsyncCursor, Cursor, YdbQuery # noqa: F401 from .errors import ( - Warning, - Error, - InterfaceError, DatabaseError, DataError, - OperationalError, + Error, IntegrityError, + InterfaceError, InternalError, - ProgrammingError, NotSupportedError, + OperationalError, + ProgrammingError, + Warning, ) -apilevel = "1.0" - -threadsafety = 0 -paramstyle = "pyformat" +class YdbDBApi: + def __init__(self): + self.paramstyle = "pyformat" + self.threadsafety = 0 + self.apilevel = "1.0" + self._init_dbapi_attributes() -errors = ( - Warning, - Error, - InterfaceError, - DatabaseError, - DataError, - OperationalError, - IntegrityError, - InternalError, - ProgrammingError, - NotSupportedError, -) + def _init_dbapi_attributes(self): + for name, value in { + "Warning": Warning, + "Error": Error, + "InterfaceError": InterfaceError, + "DatabaseError": DatabaseError, + "DataError": DataError, + "OperationalError": OperationalError, + "IntegrityError": IntegrityError, + "InternalError": InternalError, + "ProgrammingError": ProgrammingError, + "NotSupportedError": NotSupportedError, + }.items(): + setattr(self, name, value) + def connect(self, *args, **kwargs) -> Connection: + return Connection(*args, **kwargs) -def connect(*args, **kwargs): - return Connection(*args, **kwargs) + def async_connect(self, *args, **kwargs) -> AsyncConnection: + return AsyncConnection(*args, **kwargs) diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index 43e6273..df73c4c 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -1,10 +1,12 @@ +import collections.abc import posixpath -from typing import Optional, NamedTuple, Any +from typing import Any, List, NamedTuple, Optional +import sqlalchemy.util as util import ydb -from .cursor import Cursor -from .errors import InterfaceError, ProgrammingError, DatabaseError, InternalError, NotSupportedError +from .cursor import AsyncCursor, Cursor +from .errors import InterfaceError, InternalError, NotSupportedError class IsolationLevel: @@ -17,6 +19,14 @@ class IsolationLevel: class Connection: + _await = staticmethod(util.await_only) + + _is_async = False + _ydb_driver_class = ydb.Driver + _ydb_session_pool_class = ydb.SessionPool + _ydb_table_client_class = ydb.TableClient + _cursor_class = Cursor + def __init__( self, host: str = "", @@ -31,37 +41,36 @@ def __init__( if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually self._shared_session_pool = True self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool") - self.driver = self.session_pool._pool_impl._driver - self.driver.table_client = ydb.TableClient(self.driver, self._get_table_client_settings()) + self.driver = ( + self.session_pool._driver + if hasattr(self.session_pool, "_driver") + else self.session_pool._pool_impl._driver + ) + self.driver.table_client = self._ydb_table_client_class(self.driver, self._get_table_client_settings()) else: self._shared_session_pool = False self.driver = self._create_driver() - self.session_pool = ydb.SessionPool(self.driver, size=5, workers_threads_count=1) + self.session_pool = self._ydb_session_pool_class(self.driver, size=5) self.interactive_transaction: bool = False # AUTOCOMMIT self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite() self.tx_context: Optional[ydb.TxContext] = None def cursor(self): - return Cursor(self.session_pool, self.tx_context) + return self._cursor_class(self.session_pool, self.tx_mode, self.tx_context) - def describe(self, table_path): - full_path = posixpath.join(self.database, table_path) - try: - return self.session_pool.retry_operation_sync(lambda session: session.describe_table(full_path)) - except ydb.issues.SchemeError as e: - raise ProgrammingError(e.message, e.issues, e.status) from e - except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e - except Exception as e: - raise DatabaseError(f"Failed to describe table {table_path}") from e + def describe(self, table_path: str) -> ydb.TableDescription: + abs_table_path = posixpath.join(self.database, table_path) + cursor = self.cursor() + return cursor.describe_table(abs_table_path) - def check_exists(self, table_path): - try: - self.driver.scheme_client.describe_path(table_path) - return True - except ydb.SchemeError: - return False + def check_exists(self, table_path: str) -> ydb.SchemeEntry: + cursor = self.cursor() + return cursor.check_exists(table_path) + + def get_table_names(self) -> List[str]: + cursor = self.cursor() + return cursor.get_table_names() def set_isolation_level(self, isolation_level: str): class IsolationSettings(NamedTuple): @@ -105,28 +114,34 @@ def get_isolation_level(self) -> str: def begin(self): self.tx_context = None if self.interactive_transaction: - session = self.session_pool.acquire(blocking=True) + session = self._maybe_await(self.session_pool.acquire) self.tx_context = session.transaction(self.tx_mode) - self.tx_context.begin() + self._maybe_await(self.tx_context.begin) def commit(self): if self.tx_context and self.tx_context.tx_id: - self.tx_context.commit() - self.session_pool.release(self.tx_context.session) + self._maybe_await(self.tx_context.commit) + self._maybe_await(self.session_pool.release, self.tx_context.session) self.tx_context = None def rollback(self): if self.tx_context and self.tx_context.tx_id: - self.tx_context.rollback() - self.session_pool.release(self.tx_context.session) + self._maybe_await(self.tx_context.rollback) + self._maybe_await(self.session_pool.release, self.tx_context.session) self.tx_context = None def close(self): self.rollback() if not self._shared_session_pool: - self.session_pool.stop() + self._maybe_await(self.session_pool.stop) self._stop_driver() + @classmethod + def _maybe_await(cls, callee: collections.abc.Callable, *args, **kwargs) -> Any: + if cls._is_async: + return cls._await(callee(*args, **kwargs)) + return callee(*args, **kwargs) + def _get_table_client_settings(self) -> ydb.TableClientSettings: return ( ydb.TableClientSettings() @@ -143,15 +158,23 @@ def _create_driver(self): database=self.database, table_client_settings=self._get_table_client_settings(), ) - driver = ydb.Driver(driver_config) + driver = self._ydb_driver_class(driver_config) try: - driver.wait(timeout=5, fail_fast=True) + self._maybe_await(driver.wait, timeout=5, fail_fast=True) except ydb.Error as e: - raise InterfaceError(e.message, e.issues, e.status) from e + raise InterfaceError(e.message, original_error=e) from e except Exception as e: - driver.stop() + self._maybe_await(driver.stop) raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e return driver def _stop_driver(self): - self.driver.stop() + self._maybe_await(self.driver.stop) + + +class AsyncConnection(Connection): + _is_async = True + _ydb_driver_class = ydb.aio.Driver + _ydb_session_pool_class = ydb.aio.SessionPool + _ydb_table_client_class = ydb.aio.table.TableClient + _cursor_class = AsyncCursor diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 4ae9565..27e6593 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -1,18 +1,22 @@ +import collections.abc import dataclasses +import functools import itertools import logging -from typing import Any, Mapping, Optional, Sequence, Union, Dict, Callable +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union import ydb +import ydb.aio +from sqlalchemy import util from .errors import ( - InternalError, - IntegrityError, - DataError, DatabaseError, - ProgrammingError, - OperationalError, + DataError, + IntegrityError, + InternalError, NotSupportedError, + OperationalError, + ProgrammingError, ) logger = logging.getLogger(__name__) @@ -31,19 +35,74 @@ class YdbQuery: is_ddl: bool = False -class Cursor(object): +def _handle_ydb_errors(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: + raise IntegrityError(e.message, original_error=e) from e + except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: + raise NotSupportedError(e.message, original_error=e) from e + except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: + raise ProgrammingError(e.message, original_error=e) from e + except ( + ydb.issues.TruncatedResponseError, + ydb.issues.ConnectionError, + ydb.issues.Aborted, + ydb.issues.Unavailable, + ydb.issues.Overloaded, + ydb.issues.Undetermined, + ydb.issues.Timeout, + ydb.issues.Cancelled, + ydb.issues.SessionBusy, + ydb.issues.SessionExpired, + ydb.issues.SessionPoolEmpty, + ) as e: + raise OperationalError(e.message, original_error=e) from e + except ydb.issues.GenericError as e: + raise DataError(e.message, original_error=e) from e + except ydb.issues.InternalError as e: + raise InternalError(e.message, original_error=e) from e + except ydb.Error as e: + raise DatabaseError(e.message, original_error=e) from e + except Exception as e: + raise DatabaseError("Failed to execute query") from e + + return wrapper + + +class Cursor: def __init__( self, - session_pool: ydb.SessionPool, + session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], + tx_mode: ydb.AbstractTransactionModeBuilder, tx_context: Optional[ydb.BaseTxContext] = None, ): self.session_pool = session_pool + self.tx_mode = tx_mode self.tx_context = tx_context self.description = None self.arraysize = 1 self.rows = None self._rows_prefetched = None + @_handle_ydb_errors + def describe_table(self, abs_table_path: str) -> ydb.TableDescription: + return self._retry_operation_in_pool(self._describe_table, abs_table_path) + + def check_exists(self, table_path: str) -> bool: + try: + self._retry_operation_in_pool(self._describe_path, table_path) + return True + except ydb.SchemeError: + return False + + @_handle_ydb_errors + def get_table_names(self) -> List[str]: + directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory) + return [child.name for child in directory.children if child.is_table()] + def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None): if operation.is_ddl or not operation.parameters_types: query = operation.yql_text @@ -54,12 +113,9 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = logger.info("execute sql: %s, params: %s", query, parameters) if is_ddl: - chunks = self.session_pool.retry_operation_sync(self._execute_ddl, None, query) + chunks = self._execute_ddl(query) else: - if self.tx_context: - chunks = self._execute_dml(self.tx_context.session, query, parameters, self.tx_context) - else: - chunks = self.session_pool.retry_operation_sync(self._execute_dml, None, query, parameters) + chunks = self._execute_dml(query, parameters) rows = self._rows_iterable(chunks) # Prefetch the description: @@ -74,57 +130,69 @@ def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = self.rows = rows - @classmethod + @_handle_ydb_errors def _execute_dml( - cls, - session: ydb.Session, - query: ydb.DataQuery, - parameters: Optional[Mapping[str, Any]] = None, - tx_context: Optional[ydb.BaseTxContext] = None, + self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None ) -> ydb.convert.ResultSets: prepared_query = query if isinstance(query, str) and parameters: - prepared_query = session.prepare(query) + if self.tx_context: + prepared_query = self._run_operation_in_session(self._prepare, query) + else: + prepared_query = self._retry_operation_in_pool(self._prepare, query) - if tx_context: - return cls._handle_ydb_errors(tx_context.execute, prepared_query, parameters) + if self.tx_context: + return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters) - return cls._handle_ydb_errors(session.transaction().execute, prepared_query, parameters, commit_tx=True) + return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters) - @classmethod - def _execute_ddl(cls, session: ydb.Session, query: str) -> ydb.convert.ResultSets: - return cls._handle_ydb_errors(session.execute_scheme, query) + @_handle_ydb_errors + def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: + return self._retry_operation_in_pool(self._execute_scheme, query) @staticmethod - def _handle_ydb_errors(callee: Callable, *args, **kwargs) -> Any: - try: - return callee(*args, **kwargs) - except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: - raise IntegrityError(e.message, e.issues, e.status) from e - except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: - raise NotSupportedError(e.message, e.issues, e.status) from e - except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: - raise ProgrammingError(e.message, e.issues, e.status) from e - except ( - ydb.issues.TruncatedResponseError, - ydb.issues.ConnectionError, - ydb.issues.Aborted, - ydb.issues.Unavailable, - ydb.issues.Overloaded, - ydb.issues.Undetermined, - ydb.issues.Timeout, - ydb.issues.Cancelled, - ydb.issues.SessionBusy, - ydb.issues.SessionExpired, - ydb.issues.SessionPoolEmpty, - ) as e: - raise OperationalError(e.message, e.issues, e.status) from e - except ydb.issues.GenericError as e: - raise DataError(e.message, e.issues, e.status) from e - except ydb.issues.InternalError as e: - raise InternalError(e.message, e.issues, e.status) from e - except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e + def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets: + return session.execute_scheme(query) + + @staticmethod + def _describe_table(session: ydb.Session, abs_table_path: str) -> ydb.TableDescription: + return session.describe_table(abs_table_path) + + @staticmethod + def _describe_path(session: ydb.Session, table_path: str) -> ydb.SchemeEntry: + return session._driver.scheme_client.describe_path(table_path) + + @staticmethod + def _list_directory(session: ydb.Session) -> ydb.Directory: + return session._driver.scheme_client.list_directory(session._driver._driver_config.database) + + @staticmethod + def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery: + return session.prepare(query) + + @staticmethod + def _execute_in_tx( + tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + ) -> ydb.convert.ResultSets: + return tx_context.execute(prepared_query, parameters, commit_tx=False) + + @staticmethod + def _execute_in_session( + session: ydb.Session, + tx_mode: ydb.AbstractTransactionModeBuilder, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], + ) -> ydb.convert.ResultSets: + return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) + + def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): + return callee(self.tx_context, *args, **kwargs) + + def _run_operation_in_session(self, callee: collections.abc.Callable, *args, **kwargs): + return callee(self.tx_context.session, *args, **kwargs) + + def _retry_operation_in_pool(self, callee: collections.abc.Callable, *args, **kwargs): + return self.session_pool.retry_operation_sync(callee, None, *args, **kwargs) def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): try: @@ -146,7 +214,7 @@ def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany yield row[::] except ydb.Error as e: - raise DatabaseError(e.message, e.issues, e.status) from e + raise DatabaseError(e.message, original_error=e) from e def _ensure_prefetched(self): if self.rows is not None and self._rows_prefetched is None: @@ -186,3 +254,51 @@ def close(self): @property def rowcount(self): return len(self._ensure_prefetched()) + + +class AsyncCursor(Cursor): + _await = staticmethod(util.await_only) + + @staticmethod + async def _describe_table(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.TableDescription: + return await session.describe_table(abs_table_path) + + @staticmethod + async def _describe_path(session: ydb.aio.table.Session, table_path: str) -> ydb.SchemeEntry: + return await session._driver.scheme_client.describe_path(table_path) + + @staticmethod + async def _list_directory(session: ydb.aio.table.Session) -> ydb.Directory: + return await session._driver.scheme_client.list_directory(session._driver._driver_config.database) + + @staticmethod + async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets: + return await session.execute_scheme(query) + + @staticmethod + async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery: + return await session.prepare(query) + + @staticmethod + async def _execute_in_tx( + tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + ) -> ydb.convert.ResultSets: + return await tx_context.execute(prepared_query, parameters, commit_tx=False) + + @staticmethod + async def _execute_in_session( + session: ydb.aio.table.Session, + tx_mode: ydb.AbstractTransactionModeBuilder, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], + ) -> ydb.convert.ResultSets: + return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) + + def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs): + return self._await(callee(self.tx_context, *args, **kwargs)) + + def _run_operation_in_session(self, callee: collections.abc.Coroutine, *args, **kwargs): + return self._await(callee(self.tx_context.session, *args, **kwargs)) + + def _retry_operation_in_pool(self, callee: collections.abc.Coroutine, *args, **kwargs): + return self._await(self.session_pool.retry_operation(callee, *args, **kwargs)) diff --git a/ydb_sqlalchemy/dbapi/errors.py b/ydb_sqlalchemy/dbapi/errors.py index a67c0ce..70b55eb 100644 --- a/ydb_sqlalchemy/dbapi/errors.py +++ b/ydb_sqlalchemy/dbapi/errors.py @@ -1,15 +1,27 @@ +from typing import Optional, List + +import ydb +from google.protobuf.message import Message + + class Warning(Exception): pass class Error(Exception): - def __init__(self, message, issues=None, status=None): + def __init__( + self, + message: str, + original_error: Optional[ydb.Error] = None, + ): super(Error, self).__init__(message) - pretty_issues = _pretty_issues(issues) - self.issues = issues - self.message = pretty_issues or message - self.status = status + self.original_error = original_error + if original_error: + pretty_issues = _pretty_issues(original_error.issues) + self.issues = original_error.issues + self.message = pretty_issues or message + self.status = original_error.status class InterfaceError(Error): @@ -44,7 +56,7 @@ class NotSupportedError(DatabaseError): pass -def _pretty_issues(issues): +def _pretty_issues(issues: List[Message]) -> str: if issues is None: return None @@ -56,7 +68,7 @@ def _pretty_issues(issues): return "\n" + "\n".join(children_messages) -def _get_messages(issue, max_depth=100, indent=2, depth=0, root=False): +def _get_messages(issue: Message, max_depth: int = 100, indent: int = 2, depth: int = 0, root: bool = False) -> str: if depth >= max_depth: return None diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 297d4ff..a3cb0e4 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -4,27 +4,27 @@ """ import collections import collections.abc -import ydb -import ydb_sqlalchemy.dbapi as dbapi -from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS -from ydb_sqlalchemy.sqlalchemy.dml import Upsert +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import sqlalchemy as sa +import ydb +from sqlalchemy.engine import reflection +from sqlalchemy.engine.default import DefaultExecutionContext, StrCompileDialect from sqlalchemy.exc import CompileError, NoSuchTableError from sqlalchemy.sql import functions, literal_column from sqlalchemy.sql.compiler import ( - selectable, + DDLCompiler, IdentifierPreparer, - StrSQLTypeCompiler, StrSQLCompiler, - DDLCompiler, + StrSQLTypeCompiler, + selectable, ) from sqlalchemy.sql.elements import ClauseList -from sqlalchemy.engine import reflection -from sqlalchemy.engine.default import StrCompileDialect, DefaultExecutionContext from sqlalchemy.util.compat import inspect_getfullargspec -from typing import Any, Union, Mapping, Sequence, Optional, Tuple, List, Dict +import ydb_sqlalchemy.dbapi as dbapi +from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS +from ydb_sqlalchemy.sqlalchemy.dml import Upsert from . import types @@ -479,7 +479,7 @@ class YqlDialect(StrCompileDialect): @classmethod def import_dbapi(cls: Any): - return dbapi + return dbapi.YdbDBApi() def _describe_table(self, connection, table_name, schema=None): if schema is not None: @@ -510,15 +510,12 @@ def get_columns(self, connection, table_name, schema=None, **kw): return as_compatible @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def get_table_names(self, connection, schema=None, **kw) -> List[str]: if schema: raise dbapi.NotSupportedError("unsupported on non empty schema") - driver = connection.connection.driver_connection.driver - db_path = driver._driver_config.database - children = driver.scheme_client.list_directory(db_path).children - - return [child.name for child in children if child.is_table()] + raw_conn = connection.connection + return raw_conn.get_table_names() @reflection.cache def has_table(self, connection, table_name, schema=None, **kwargs): @@ -552,6 +549,9 @@ def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str: def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str: return dbapi_connection.get_isolation_level() + def connect(self, *cargs, **cparams): + return self.loaded_dbapi.connect(*cargs, **cparams) + def do_begin(self, dbapi_connection: dbapi.Connection) -> None: dbapi_connection.begin() @@ -634,3 +634,12 @@ def do_execute( ) -> None: operation, parameters = self._make_ydb_operation(statement, context, parameters, execute_many=False) cursor.execute(operation, parameters) + + +class AsyncYqlDialect(YqlDialect): + driver = "ydb_async" + is_async = True + supports_statement_cache = False + + def connect(self, *cargs, **cparams): + return self.loaded_dbapi.async_connect(*cargs, **cparams) diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index 61fa5ca..8570e9a 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -1,6 +1,7 @@ -from sqlalchemy import exc, ColumnElement, ARRAY, types +from typing import Any, Mapping, Type, Union + +from sqlalchemy import ARRAY, ColumnElement, exc, types from sqlalchemy.sql import type_api -from typing import Mapping, Any, Union, Type class UInt32(types.Integer):