diff --git a/README.md b/README.md index f5796db..0961293 100644 --- a/README.md +++ b/README.md @@ -46,7 +46,7 @@ $ tox -e test-all Run specific test: ```bash -$ tox -e test -- test_dbapi/test_dbapi.py +$ tox -e test -- test/test_core.py ``` Check code style: diff --git a/docker-compose.yml b/docker-compose.yml index 13badc8..3ec3550 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,10 +1,11 @@ version: "3.3" services: ydb: - image: cr.yandex/yc/yandex-docker-local-ydb:trunk + image: ydbplatform/local-ydb:trunk restart: always ports: - "2136:2136" hostname: localhost environment: - YDB_USE_IN_MEMORY_PDISKS=true + - YDB_ENABLE_COLUMN_TABLES=true diff --git a/requirements.txt b/requirements.txt index 7408917..ec6276b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,3 @@ sqlalchemy >= 2.0.7, < 3.0.0 -ydb >= 3.11.3 +ydb >= 3.18.8 +ydb-dbapi==0.0.1b7 diff --git a/setup.py b/setup.py index 9c8e82b..c6e4617 100644 --- a/setup.py +++ b/setup.py @@ -26,7 +26,7 @@ classifiers=[ "Programming Language :: Python", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], keywords="SQLAlchemy YDB YQL", install_requires=requirements, # requirements.txt diff --git a/test-requirements.txt b/test-requirements.txt index 22b7a2c..d7ff076 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,6 +1,9 @@ -sqlalchemy==2.0.7 -ydb==3.11.3 +pyyaml==5.3.1 +greenlet +sqlalchemy==2.0.7 +ydb >= 3.18.8 +ydb-dbapi==0.0.1b7 requests<2.29 pytest==7.2.2 docker==6.0.1 diff --git a/test/test_core.py b/test/test_core.py index 5fdcb84..61129f8 100644 --- a/test/test_core.py +++ b/test/test_core.py @@ -33,7 +33,18 @@ def test_sa_text(self, connection): SELECT x, y FROM AS_TABLE(:data) """ ), - [{"data": [{"x": 2, "y": 1}, {"x": 3, "y": 2}]}], + [ + { + "data": ydb.TypedValue( + [{"x": 2, "y": 1}, {"x": 3, "y": 2}], + ydb.ListType( + ydb.StructType() + .add_member("x", ydb.PrimitiveType.Int64) + .add_member("y", ydb.PrimitiveType.Int64) + ), + ) + } + ], ) assert set(rs.fetchall()) == {(2, 1), (3, 2)} @@ -454,85 +465,6 @@ def test_several_keys(self, connection, metadata): assert desc.partitioning_settings.max_partitions_count == 5 -class TestScanQuery(TablesTest): - __backend__ = True - - @classmethod - def define_tables(cls, metadata: sa.MetaData): - Table( - "test", - metadata, - Column("id", Integer, primary_key=True), - ) - - @classmethod - def insert_data(cls, connection: sa.Connection): - table = cls.tables.test - for i in range(50): - connection.execute(ydb_sa.upsert(table).values([{"id": i * 1000 + j} for j in range(1000)])) - - def test_characteristic(self): - engine = self.bind.execution_options() - - with engine.connect() as connection: - default_options = connection.get_execution_options() - - with engine.connect() as connection: - connection.execution_options(ydb_scan_query=True) - options_after_set = connection.get_execution_options() - - with engine.connect() as connection: - options_after_reset = connection.get_execution_options() - - assert "ydb_scan_query" not in default_options - assert options_after_set["ydb_scan_query"] - assert "ydb_scan_query" not in options_after_reset - - def test_fetchmany(self, connection_no_trans: sa.Connection): - table = self.tables.test - stmt = sa.select(table).where(table.c.id % 2 == 0) - - connection_no_trans.execution_options(ydb_scan_query=True) - cursor = connection_no_trans.execute(stmt) - - assert cursor.cursor.use_scan_query - result = cursor.fetchmany(1000) # fetches only the first 5k rows - assert result == [(i,) for i in range(2000) if i % 2 == 0] - - def test_fetchall(self, connection_no_trans: sa.Connection): - table = self.tables.test - stmt = sa.select(table).where(table.c.id % 2 == 0) - - connection_no_trans.execution_options(ydb_scan_query=True) - cursor = connection_no_trans.execute(stmt) - - assert cursor.cursor.use_scan_query - result = cursor.fetchall() - assert result == [(i,) for i in range(50000) if i % 2 == 0] - - def test_begin_does_nothing(self, connection_no_trans: sa.Connection): - table = self.tables.test - connection_no_trans.execution_options(ydb_scan_query=True) - - with connection_no_trans.begin(): - cursor = connection_no_trans.execute(sa.select(table)) - - assert cursor.cursor.use_scan_query - assert cursor.cursor.tx_context is None - - def test_engine_option(self): - table = self.tables.test - engine = self.bind.execution_options(ydb_scan_query=True) - - with engine.begin() as connection: - cursor = connection.execute(sa.select(table)) - assert cursor.cursor.use_scan_query - - with engine.begin() as connection: - cursor = connection.execute(sa.select(table)) - assert cursor.cursor.use_scan_query - - class TestTransaction(TablesTest): __backend__ = True @@ -585,11 +517,11 @@ def test_interactive_transaction( connection_no_trans.execution_options(isolation_level=isolation_level) with connection_no_trans.begin(): - tx_id = dbapi_connection.tx_context.tx_id - assert tx_id is not None cursor1 = connection_no_trans.execute(sa.select(table)) + tx_id = dbapi_connection._tx_context.tx_id + assert tx_id is not None cursor2 = connection_no_trans.execute(sa.select(table)) - assert dbapi_connection.tx_context.tx_id == tx_id + assert dbapi_connection._tx_context.tx_id == tx_id assert set(cursor1.fetchall()) == {(5,), (6,)} assert set(cursor2.fetchall()) == {(5,), (6,)} @@ -614,10 +546,10 @@ def test_not_interactive_transaction( connection_no_trans.execution_options(isolation_level=isolation_level) with connection_no_trans.begin(): - assert dbapi_connection.tx_context is None + assert dbapi_connection._tx_context is None cursor1 = connection_no_trans.execute(sa.select(table)) cursor2 = connection_no_trans.execute(sa.select(table)) - assert dbapi_connection.tx_context is None + assert dbapi_connection._tx_context is None assert set(cursor1.fetchall()) == {(7,), (8,)} assert set(cursor2.fetchall()) == {(7,), (8,)} @@ -631,14 +563,14 @@ class IsolationSettings(NamedTuple): interactive: bool YDB_ISOLATION_SETTINGS_MAP = { - IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite().name, False), - IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite().name, True), - IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly().name, False), + IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.QuerySerializableReadWrite().name, False), + IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.QuerySerializableReadWrite().name, True), + IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.QueryOnlineReadOnly().name, False), IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings( - ydb.OnlineReadOnly().with_allow_inconsistent_reads().name, False + ydb.QueryOnlineReadOnly().with_allow_inconsistent_reads().name, False ), - IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly().name, False), - IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly().name, True), + IsolationLevel.STALE_READONLY: IsolationSettings(ydb.QueryStaleReadOnly().name, False), + IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True), } def test_connection_set(self, connection_no_trans: sa.Connection): @@ -647,13 +579,13 @@ def test_connection_set(self, connection_no_trans: sa.Connection): for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items(): connection_no_trans.execution_options(isolation_level=sa_isolation_level) with connection_no_trans.begin(): - assert dbapi_connection.tx_mode.name == ydb_isolation_settings[0] + assert dbapi_connection._tx_mode.name == ydb_isolation_settings[0] assert dbapi_connection.interactive_transaction is ydb_isolation_settings[1] if dbapi_connection.interactive_transaction: - assert dbapi_connection.tx_context is not None - assert dbapi_connection.tx_context.tx_id is not None + assert dbapi_connection._tx_context is not None + # assert dbapi_connection._tx_context.tx_id is not None else: - assert dbapi_connection.tx_context is None + assert dbapi_connection._tx_context is None class TestEngine(TestBase): @@ -674,7 +606,7 @@ def ydb_driver(self): @pytest.fixture(scope="class") def ydb_pool(self, ydb_driver): - session_pool = ydb.SessionPool(ydb_driver, size=5, workers_threads_count=1) + session_pool = ydb.QuerySessionPool(ydb_driver, size=5) try: yield session_pool @@ -689,8 +621,8 @@ def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection - assert dbapi_conn1.session_pool is dbapi_conn2.session_pool - assert dbapi_conn1.driver is dbapi_conn2.driver + assert dbapi_conn1._session_pool is dbapi_conn2._session_pool + assert dbapi_conn1._driver is dbapi_conn2._driver engine1.dispose() engine2.dispose() @@ -704,8 +636,8 @@ def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool): dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection - assert dbapi_conn1.session_pool is dbapi_conn2.session_pool - assert dbapi_conn1.driver is dbapi_conn2.driver + assert dbapi_conn1._session_pool is dbapi_conn2._session_pool + assert dbapi_conn1._driver is dbapi_conn2._driver engine1.dispose() engine2.dispose() @@ -726,14 +658,15 @@ def ydb_driver(self): finally: loop.run_until_complete(driver.stop()) + @pytest.mark.asyncio @pytest.fixture(scope="class") def ydb_pool(self, ydb_driver): - session_pool = ydb.aio.SessionPool(ydb_driver, size=5) + loop = asyncio.get_event_loop() + session_pool = ydb.aio.QuerySessionPool(ydb_driver, size=5, loop=loop) try: yield session_pool finally: - loop = asyncio.get_event_loop() loop.run_until_complete(session_pool.stop()) @@ -742,9 +675,9 @@ class TestCredentials(TestBase): __only_on__ = "yql+ydb" @pytest.fixture(scope="class") - def table_client_settings(self): + def query_client_settings(self): yield ( - ydb.TableClientSettings() + ydb.QueryClientSettings() .with_native_date_in_result_sets(True) .with_native_datetime_in_result_sets(True) .with_native_timestamp_in_result_sets(True) @@ -753,7 +686,7 @@ def table_client_settings(self): ) @pytest.fixture(scope="class") - def driver_config_for_credentials(self, table_client_settings): + def driver_config_for_credentials(self, query_client_settings): url = config.db_url endpoint = f"grpc://{url.host}:{url.port}" database = url.database @@ -761,10 +694,10 @@ def driver_config_for_credentials(self, table_client_settings): yield ydb.DriverConfig( endpoint=endpoint, database=database, - table_client_settings=table_client_settings, + query_client_settings=query_client_settings, ) - def test_ydb_credentials_good(self, table_client_settings, driver_config_for_credentials): + def test_ydb_credentials_good(self, query_client_settings, driver_config_for_credentials): credentials_good = ydb.StaticCredentials( driver_config=driver_config_for_credentials, user="root", @@ -775,7 +708,7 @@ def test_ydb_credentials_good(self, table_client_settings, driver_config_for_cre result = conn.execute(sa.text("SELECT 1 as value")) assert result.fetchone() - def test_ydb_credentials_bad(self, table_client_settings, driver_config_for_credentials): + def test_ydb_credentials_bad(self, query_client_settings, driver_config_for_credentials): credentials_bad = ydb.StaticCredentials( driver_config=driver_config_for_credentials, user="root", diff --git a/test/test_suite.py b/test/test_suite.py index 4ff54a8..bf0bbad 100644 --- a/test/test_suite.py +++ b/test/test_suite.py @@ -506,6 +506,18 @@ def test_struct_type_bind_variable(self, connection): eq_(connection.scalar(stmt, {"struct": {"id": 1}}), 1) + def test_struct_type_bind_variable_text(self, connection): + rs = connection.execute( + sa.text("SELECT :struct.x + :struct.y").bindparams( + sa.bindparam( + key="struct", + type_=ydb_sa_types.StructType({"x": sa.Integer, "y": sa.Integer}), + value={"x": 1, "y": 2}, + ) + ) + ) + assert rs.scalar() == 3 + def test_from_as_table(self, connection): table = self.tables.container_types_test diff --git a/test_dbapi/__init__.py b/test_dbapi/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test_dbapi/test_dbapi.py b/test_dbapi/test_dbapi.py deleted file mode 100644 index 08e4a9b..0000000 --- a/test_dbapi/test_dbapi.py +++ /dev/null @@ -1,196 +0,0 @@ -from contextlib import suppress - -import pytest -import pytest_asyncio -import sqlalchemy.util as util -import ydb - -import ydb_sqlalchemy.dbapi as dbapi - - -class BaseDBApiTestSuit: - def _test_isolation_level_read_only(self, connection: dbapi.Connection, isolation_level: str, read_only: bool): - connection.cursor().execute( - dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True) - ) - connection.set_isolation_level(isolation_level) - - cursor = connection.cursor() - - connection.begin() - - query = dbapi.YdbQuery("UPSERT INTO foo(id) VALUES (1)") - if read_only: - with pytest.raises(dbapi.DatabaseError): - cursor.execute(query) - else: - cursor.execute(query) - - connection.rollback() - - connection.cursor().execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) - connection.cursor().close() - - def _test_connection(self, connection: dbapi.Connection): - connection.commit() - connection.rollback() - - cur = connection.cursor() - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) - - assert not connection.check_exists("/local/foo") - with pytest.raises(dbapi.ProgrammingError): - connection.describe("/local/foo") - - cur.execute(dbapi.YdbQuery("CREATE TABLE foo(id Int64 NOT NULL, PRIMARY KEY (id))", is_ddl=True)) - - assert connection.check_exists("/local/foo") - - col = connection.describe("/local/foo").columns[0] - assert col.name == "id" - assert col.type == ydb.PrimitiveType.Int64 - - cur.execute(dbapi.YdbQuery("DROP TABLE foo", is_ddl=True)) - cur.close() - - def _test_cursor_raw_query(self, connection: dbapi.Connection): - cur = connection.cursor() - assert cur - - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - - cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64 NOT NULL, text Utf8, PRIMARY KEY (id))", is_ddl=True)) - - cur.execute( - dbapi.YdbQuery( - """ - DECLARE $data AS List>; - - INSERT INTO test SELECT id, text FROM AS_TABLE($data); - """, - parameters_types={ - "$data": ydb.ListType( - ydb.StructType() - .add_member("id", ydb.PrimitiveType.Int64) - .add_member("text", ydb.PrimitiveType.Utf8) - ) - }, - ), - { - "$data": [ - {"id": 17, "text": "seventeen"}, - {"id": 21, "text": "twenty one"}, - ] - }, - ) - - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - - cur.close() - - def _test_errors(self, connection: dbapi.Connection): - with pytest.raises(dbapi.InterfaceError): - dbapi.YdbDBApi().connect("localhost:2136", database="/local666") - - cur = connection.cursor() - - with suppress(dbapi.DatabaseError): - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT 18446744073709551616")) - - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT * FROM 拉屎")) - - with pytest.raises(dbapi.DataError): - cur.execute(dbapi.YdbQuery("SELECT floor(5 / 2)")) - - with pytest.raises(dbapi.ProgrammingError): - cur.execute(dbapi.YdbQuery("SELECT * FROM test")) - - cur.execute(dbapi.YdbQuery("CREATE TABLE test(id Int64, PRIMARY KEY (id))", is_ddl=True)) - - cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - with pytest.raises(dbapi.IntegrityError): - cur.execute(dbapi.YdbQuery("INSERT INTO test(id) VALUES(1)")) - - cur.execute(dbapi.YdbQuery("DROP TABLE test", is_ddl=True)) - cur.close() - - -class TestSyncConnection(BaseDBApiTestSuit): - @pytest.fixture - def sync_connection(self) -> dbapi.Connection: - conn = dbapi.YdbDBApi().connect(host="localhost", port="2136", database="/local") - try: - yield conn - finally: - conn.close() - - @pytest.mark.parametrize( - "isolation_level, read_only", - [ - (dbapi.IsolationLevel.SERIALIZABLE, False), - (dbapi.IsolationLevel.AUTOCOMMIT, False), - (dbapi.IsolationLevel.ONLINE_READONLY, True), - (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), - (dbapi.IsolationLevel.STALE_READONLY, True), - (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), - ], - ) - def test_isolation_level_read_only(self, isolation_level: str, read_only: bool, sync_connection: dbapi.Connection): - self._test_isolation_level_read_only(sync_connection, isolation_level, read_only) - - def test_connection(self, sync_connection: dbapi.Connection): - self._test_connection(sync_connection) - - def test_cursor_raw_query(self, sync_connection: dbapi.Connection): - return self._test_cursor_raw_query(sync_connection) - - def test_errors(self, sync_connection: dbapi.Connection): - return self._test_errors(sync_connection) - - -class TestAsyncConnection(BaseDBApiTestSuit): - @pytest_asyncio.fixture - async def async_connection(self) -> dbapi.AsyncConnection: - def connect(): - return dbapi.YdbDBApi().async_connect(host="localhost", port="2136", database="/local") - - conn = await util.greenlet_spawn(connect) - try: - yield conn - finally: - await util.greenlet_spawn(conn.close) - - @pytest.mark.asyncio - @pytest.mark.parametrize( - "isolation_level, read_only", - [ - (dbapi.IsolationLevel.SERIALIZABLE, False), - (dbapi.IsolationLevel.AUTOCOMMIT, False), - (dbapi.IsolationLevel.ONLINE_READONLY, True), - (dbapi.IsolationLevel.ONLINE_READONLY_INCONSISTENT, True), - (dbapi.IsolationLevel.STALE_READONLY, True), - (dbapi.IsolationLevel.SNAPSHOT_READONLY, True), - ], - ) - async def test_isolation_level_read_only( - self, isolation_level: str, read_only: bool, async_connection: dbapi.AsyncConnection - ): - await util.greenlet_spawn(self._test_isolation_level_read_only, async_connection, isolation_level, read_only) - - @pytest.mark.asyncio - async def test_connection(self, async_connection: dbapi.AsyncConnection): - await util.greenlet_spawn(self._test_connection, async_connection) - - @pytest.mark.asyncio - async def test_cursor_raw_query(self, async_connection: dbapi.AsyncConnection): - await util.greenlet_spawn(self._test_cursor_raw_query, async_connection) - - @pytest.mark.asyncio - async def test_errors(self, async_connection: dbapi.AsyncConnection): - await util.greenlet_spawn(self._test_errors, async_connection) diff --git a/tox.ini b/tox.ini index 41ee405..2a4c1ba 100644 --- a/tox.ini +++ b/tox.ini @@ -26,17 +26,8 @@ commands = docker-compose up -d python {toxinidir}/wait_container_ready.py pytest -v test --dbdriver ydb --dbdriver ydb_async - pytest -v test_dbapi pytest -v ydb_sqlalchemy - docker-compose down - -[testenv:test-dbapi] -ignore_errors = True -commands = - docker-compose up -d - python {toxinidir}/wait_container_ready.py - pytest -v {toxinidir}/test_dbapi - docker-compose down + docker-compose down -v [testenv:test-unit] commands = @@ -53,22 +44,22 @@ commands = [testenv:black] skip_install = true commands = - black --diff --check ydb_sqlalchemy examples test test_dbapi + black --diff --check ydb_sqlalchemy examples test [testenv:black-format] skip_install = true commands = - black ydb_sqlalchemy examples test test_dbapi + black ydb_sqlalchemy examples test [testenv:isort] skip_install = true commands = - isort ydb_sqlalchemy examples test test_dbapi + isort ydb_sqlalchemy examples test [testenv:style] ignore_errors = True commands = - flake8 ydb_sqlalchemy examples test test_dbapi + flake8 ydb_sqlalchemy examples test [flake8] show-source = true diff --git a/ydb_sqlalchemy/__init__.py b/ydb_sqlalchemy/__init__.py index d57dab0..55ade24 100644 --- a/ydb_sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/__init__.py @@ -1,3 +1,4 @@ from ._version import VERSION # noqa: F401 -from .dbapi import IsolationLevel # noqa: F401 +from ydb_dbapi import IsolationLevel # noqa: F401 from .sqlalchemy import Upsert, types, upsert # noqa: F401 +import ydb_dbapi as dbapi diff --git a/ydb_sqlalchemy/dbapi/__init__.py b/ydb_sqlalchemy/dbapi/__init__.py deleted file mode 100644 index f8fffe7..0000000 --- a/ydb_sqlalchemy/dbapi/__init__.py +++ /dev/null @@ -1,43 +0,0 @@ -from .connection import AsyncConnection, Connection, IsolationLevel # noqa: F401 -from .cursor import AsyncCursor, Cursor, YdbQuery # noqa: F401 -from .errors import ( - DatabaseError, - DataError, - Error, - IntegrityError, - InterfaceError, - InternalError, - NotSupportedError, - OperationalError, - ProgrammingError, - Warning, -) - - -class YdbDBApi: - def __init__(self): - self.paramstyle = "pyformat" - self.threadsafety = 0 - self.apilevel = "1.0" - self._init_dbapi_attributes() - - def _init_dbapi_attributes(self): - for name, value in { - "Warning": Warning, - "Error": Error, - "InterfaceError": InterfaceError, - "DatabaseError": DatabaseError, - "DataError": DataError, - "OperationalError": OperationalError, - "IntegrityError": IntegrityError, - "InternalError": InternalError, - "ProgrammingError": ProgrammingError, - "NotSupportedError": NotSupportedError, - }.items(): - setattr(self, name, value) - - def connect(self, *args, **kwargs) -> Connection: - return Connection(*args, **kwargs) - - def async_connect(self, *args, **kwargs) -> AsyncConnection: - return AsyncConnection(*args, **kwargs) diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py deleted file mode 100644 index dc4ac97..0000000 --- a/ydb_sqlalchemy/dbapi/connection.py +++ /dev/null @@ -1,207 +0,0 @@ -import collections.abc -import posixpath -from typing import Any, List, NamedTuple, Optional - -import sqlalchemy.util as util -import ydb - -from .cursor import AsyncCursor, Cursor -from .errors import InterfaceError, InternalError, NotSupportedError - - -class IsolationLevel: - SERIALIZABLE = "SERIALIZABLE" - ONLINE_READONLY = "ONLINE READONLY" - ONLINE_READONLY_INCONSISTENT = "ONLINE READONLY INCONSISTENT" - STALE_READONLY = "STALE READONLY" - SNAPSHOT_READONLY = "SNAPSHOT READONLY" - AUTOCOMMIT = "AUTOCOMMIT" - - -class Connection: - _await = staticmethod(util.await_only) - - _is_async = False - _ydb_driver_class = ydb.Driver - _ydb_session_pool_class = ydb.SessionPool - _ydb_table_client_class = ydb.TableClient - _cursor_class = Cursor - - def __init__( - self, - host: str = "", - port: str = "", - database: str = "", - **conn_kwargs: Any, - ): - self.endpoint = f"grpc://{host}:{port}" - self.database = database - self.conn_kwargs = conn_kwargs - self.credentials = self.conn_kwargs.pop("credentials", None) - self.table_path_prefix = self.conn_kwargs.pop("ydb_table_path_prefix", "") - - if "ydb_session_pool" in self.conn_kwargs: # Use session pool managed manually - self._shared_session_pool = True - self.session_pool: ydb.SessionPool = self.conn_kwargs.pop("ydb_session_pool") - self.driver = ( - self.session_pool._driver - if hasattr(self.session_pool, "_driver") - else self.session_pool._pool_impl._driver - ) - self.driver.table_client = self._ydb_table_client_class(self.driver, self._get_table_client_settings()) - else: - self._shared_session_pool = False - self.driver = self._create_driver() - self.session_pool = self._ydb_session_pool_class(self.driver, size=5) - - self.interactive_transaction: bool = False # AUTOCOMMIT - self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite() - self.tx_context: Optional[ydb.TxContext] = None - self.use_scan_query: bool = False - self.request_settings: ydb.BaseRequestSettings = ydb.BaseRequestSettings() - - def cursor(self): - return self._cursor_class( - driver=self.driver, - session_pool=self.session_pool, - tx_mode=self.tx_mode, - tx_context=self.tx_context, - request_settings=self.request_settings, - use_scan_query=self.use_scan_query, - table_path_prefix=self.table_path_prefix, - ) - - def describe(self, table_path: str) -> ydb.TableDescription: - abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path) - cursor = self.cursor() - return cursor.describe_table(abs_table_path) - - def check_exists(self, table_path: str) -> bool: - abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path) - cursor = self.cursor() - return cursor.check_exists(abs_table_path) - - def get_table_names(self) -> List[str]: - abs_dir_path = posixpath.join(self.database, self.table_path_prefix) - cursor = self.cursor() - return [posixpath.relpath(path, abs_dir_path) for path in cursor.get_table_names(abs_dir_path)] - - def set_isolation_level(self, isolation_level: str): - class IsolationSettings(NamedTuple): - ydb_mode: ydb.AbstractTransactionModeBuilder - interactive: bool - - ydb_isolation_settings_map = { - IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite(), interactive=False), - IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite(), interactive=True), - IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly(), interactive=False), - IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings( - ydb.OnlineReadOnly().with_allow_inconsistent_reads(), interactive=False - ), - IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly(), interactive=False), - IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly(), interactive=True), - } - ydb_isolation_settings = ydb_isolation_settings_map[isolation_level] - if self.tx_context and self.tx_context.tx_id: - raise InternalError("Failed to set transaction mode: transaction is already began") - self.tx_mode = ydb_isolation_settings.ydb_mode - self.interactive_transaction = ydb_isolation_settings.interactive - - def get_isolation_level(self) -> str: - if self.tx_mode.name == ydb.SerializableReadWrite().name: - if self.interactive_transaction: - return IsolationLevel.SERIALIZABLE - else: - return IsolationLevel.AUTOCOMMIT - elif self.tx_mode.name == ydb.OnlineReadOnly().name: - if self.tx_mode.settings.allow_inconsistent_reads: - return IsolationLevel.ONLINE_READONLY_INCONSISTENT - else: - return IsolationLevel.ONLINE_READONLY - elif self.tx_mode.name == ydb.StaleReadOnly().name: - return IsolationLevel.STALE_READONLY - elif self.tx_mode.name == ydb.SnapshotReadOnly().name: - return IsolationLevel.SNAPSHOT_READONLY - else: - raise NotSupportedError(f"{self.tx_mode.name} is not supported") - - def set_ydb_scan_query(self, value: bool) -> None: - self.use_scan_query = value - - def get_ydb_scan_query(self) -> bool: - return self.use_scan_query - - def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None: - self.request_settings = value - - def get_ydb_request_settings(self) -> ydb.BaseRequestSettings: - return self.request_settings - - def begin(self): - self.tx_context = None - if self.interactive_transaction and not self.use_scan_query: - session = self._maybe_await(self.session_pool.acquire) - self.tx_context = session.transaction(self.tx_mode) - self._maybe_await(self.tx_context.begin) - - def commit(self): - if self.tx_context and self.tx_context.tx_id: - self._maybe_await(self.tx_context.commit) - self._maybe_await(self.session_pool.release, self.tx_context.session) - self.tx_context = None - - def rollback(self): - if self.tx_context and self.tx_context.tx_id: - self._maybe_await(self.tx_context.rollback) - self._maybe_await(self.session_pool.release, self.tx_context.session) - self.tx_context = None - - def close(self): - self.rollback() - if not self._shared_session_pool: - self._maybe_await(self.session_pool.stop) - self._stop_driver() - - @classmethod - def _maybe_await(cls, callee: collections.abc.Callable, *args, **kwargs) -> Any: - if cls._is_async: - return cls._await(callee(*args, **kwargs)) - return callee(*args, **kwargs) - - def _get_table_client_settings(self) -> ydb.TableClientSettings: - return ( - ydb.TableClientSettings() - .with_native_date_in_result_sets(True) - .with_native_datetime_in_result_sets(True) - .with_native_timestamp_in_result_sets(True) - .with_native_interval_in_result_sets(True) - .with_native_json_in_result_sets(False) - ) - - def _create_driver(self): - driver_config = ydb.DriverConfig( - endpoint=self.endpoint, - database=self.database, - table_client_settings=self._get_table_client_settings(), - credentials=self.credentials, - ) - driver = self._ydb_driver_class(driver_config) - try: - self._maybe_await(driver.wait, timeout=5, fail_fast=True) - except ydb.Error as e: - raise InterfaceError(e.message, original_error=e) from e - except Exception as e: - self._maybe_await(driver.stop) - raise InterfaceError(f"Failed to connect to YDB, details {driver.discovery_debug_details()}") from e - return driver - - def _stop_driver(self): - self._maybe_await(self.driver.stop) - - -class AsyncConnection(Connection): - _is_async = True - _ydb_driver_class = ydb.aio.Driver - _ydb_session_pool_class = ydb.aio.SessionPool - _ydb_table_client_class = ydb.aio.table.TableClient - _cursor_class = AsyncCursor diff --git a/ydb_sqlalchemy/dbapi/constants.py b/ydb_sqlalchemy/dbapi/constants.py deleted file mode 100644 index a27aef1..0000000 --- a/ydb_sqlalchemy/dbapi/constants.py +++ /dev/null @@ -1,218 +0,0 @@ -YDB_KEYWORDS = { - "abort", - "action", - "add", - "after", - "all", - "alter", - "analyze", - "and", - "ansi", - "any", - "array", - "as", - "asc", - "assume", - "async", - "attach", - "autoincrement", - "before", - "begin", - "bernoulli", - "between", - "bitcast", - "by", - "cascade", - "case", - "cast", - "changefeed", - "check", - "collate", - "column", - "columns", - "commit", - "compact", - "conditional", - "conflict", - "constraint", - "consumer", - "cover", - "create", - "cross", - "cube", - "current", - "current_date", - "current_time", - "current_timestamp", - "data", - "database", - "decimal", - "declare", - "default", - "deferrable", - "deferred", - "define", - "delete", - "desc", - "detach", - "disable", - "discard", - "distinct", - "do", - "drop", - "each", - "else", - "empty", - "empty_action", - "encrypted", - "end", - "erase", - "error", - "escape", - "evaluate", - "except", - "exclude", - "exclusion", - "exclusive", - "exists", - "explain", - "export", - "external", - "fail", - "family", - "filter", - "flatten", - "following", - "for", - "foreign", - "from", - "full", - "function", - "glob", - "group", - "grouping", - "groups", - "hash", - "having", - "hop", - "if", - "ignore", - "ilike", - "immediate", - "import", - "in", - "index", - "indexed", - "inherits", - "initially", - "inner", - "insert", - "instead", - "intersect", - "into", - "is", - "isnull", - "join", - "json_exists", - "json_query", - "json_value", - "key", - "left", - "like", - "limit", - "local", - "match", - "natural", - "no", - "not", - "notnull", - "null", - "nulls", - "object", - "of", - "offset", - "on", - "only", - "or", - "order", - "others", - "outer", - "over", - "partition", - "passing", - "password", - "plan", - "pragma", - "preceding", - "presort", - "primary", - "process", - "raise", - "range", - "reduce", - "references", - "regexp", - "reindex", - "release", - "rename", - "replace", - "replication", - "reset", - "respect", - "restrict", - "result", - "return", - "returning", - "revert", - "right", - "rlike", - "rollback", - "rollup", - "row", - "rows", - "sample", - "savepoint", - "schema", - "select", - "semi", - "sets", - "source", - "stream", - "subquery", - "symbols", - "sync", - "system", - "table", - "tablesample", - "tablestore", - "temp", - "temporary", - "then", - "ties", - "to", - "topic", - "transaction", - "trigger", - "type", - "unbounded", - "unconditional", - "union", - "unique", - "unknown", - "update", - "upsert", - "use", - "user", - "using", - "vacuum", - "values", - "view", - "virtual", - "when", - "where", - "window", - "with", - "without", - "wrapper", - "xor", -} diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py deleted file mode 100644 index c104356..0000000 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ /dev/null @@ -1,435 +0,0 @@ -import collections.abc -import dataclasses -import functools -import hashlib -import itertools -import posixpath -from collections.abc import AsyncIterator -from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Union - -import ydb -import ydb.aio -from sqlalchemy import util - -from .errors import ( - DatabaseError, - DataError, - IntegrityError, - InternalError, - NotSupportedError, - OperationalError, - ProgrammingError, -) -from .tracing import maybe_get_current_trace_id - - -def get_column_type(type_obj: Any) -> str: - return str(ydb.convert.type_to_native(type_obj)) - - -@dataclasses.dataclass -class YdbQuery: - yql_text: str - parameters_types: Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]] = dataclasses.field( - default_factory=dict - ) - is_ddl: bool = False - - -def _handle_ydb_errors(func): - @functools.wraps(func) - def wrapper(*args, **kwargs): - try: - return func(*args, **kwargs) - except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: - raise IntegrityError(e.message, original_error=e) from e - except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: - raise NotSupportedError(e.message, original_error=e) from e - except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: - raise ProgrammingError(e.message, original_error=e) from e - except ( - ydb.issues.TruncatedResponseError, - ydb.issues.ConnectionError, - ydb.issues.Aborted, - ydb.issues.Unavailable, - ydb.issues.Overloaded, - ydb.issues.Undetermined, - ydb.issues.Timeout, - ydb.issues.Cancelled, - ydb.issues.SessionBusy, - ydb.issues.SessionExpired, - ydb.issues.SessionPoolEmpty, - ) as e: - raise OperationalError(e.message, original_error=e) from e - except ydb.issues.GenericError as e: - raise DataError(e.message, original_error=e) from e - except ydb.issues.InternalError as e: - raise InternalError(e.message, original_error=e) from e - except ydb.Error as e: - raise DatabaseError(e.message, original_error=e) from e - except Exception as e: - raise DatabaseError("Failed to execute query") from e - - return wrapper - - -class Cursor: - def __init__( - self, - driver: Union[ydb.Driver, ydb.aio.Driver], - session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], - tx_mode: ydb.AbstractTransactionModeBuilder, - request_settings: ydb.BaseRequestSettings, - tx_context: Optional[ydb.BaseTxContext] = None, - use_scan_query: bool = False, - table_path_prefix: str = "", - ): - self.driver = driver - self.session_pool = session_pool - self.tx_mode = tx_mode - self.request_settings = request_settings - self.tx_context = tx_context - self.use_scan_query = use_scan_query - self.root_directory = table_path_prefix - self.description = None - self.arraysize = 1 - self.rows = None - self._rows_prefetched = None - - @_handle_ydb_errors - def describe_table(self, abs_table_path: str) -> ydb.TableDescription: - settings = self._get_request_settings() - return self._retry_operation_in_pool(self._describe_table, abs_table_path, settings) - - def check_exists(self, abs_table_path: str) -> bool: - settings = self._get_request_settings() - try: - self._retry_operation_in_pool(self._describe_path, abs_table_path, settings) - return True - except ydb.SchemeError: - return False - - @_handle_ydb_errors - def get_table_names(self, abs_dir_path: str) -> List[str]: - settings = self._get_request_settings() - directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path, settings) - result = [] - for child in directory.children: - child_abs_path = posixpath.join(abs_dir_path, child.name) - if child.is_table(): - result.append(child_abs_path) - elif child.is_directory() and not child.name.startswith("."): - result.extend(self.get_table_names(child_abs_path)) - return result - - def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None): - query = self._get_ydb_query(operation) - - if operation.is_ddl: - chunks = self._execute_ddl(query) - elif self.use_scan_query: - chunks = self._execute_scan_query(query, parameters) - else: - chunks = self._execute_dml(query, parameters) - - rows = self._rows_iterable(chunks) - # Prefetch the description: - try: - first_row = next(rows) - except StopIteration: - pass - else: - rows = itertools.chain((first_row,), rows) - if self.rows is not None: - rows = itertools.chain(self.rows, rows) - - self.rows = rows - - def _get_ydb_query(self, operation: YdbQuery) -> Union[ydb.DataQuery, str]: - pragma = "" - if self.root_directory: - pragma = f'PRAGMA TablePathPrefix = "{self.root_directory}";\n' - - yql_with_pragma = pragma + operation.yql_text - - if operation.is_ddl or not operation.parameters_types: - return yql_with_pragma - - return self._make_data_query(yql_with_pragma, operation.parameters_types) - - def _make_data_query( - self, - yql_text: str, - parameters_types: Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]], - ) -> ydb.DataQuery: - """ - ydb.DataQuery uses hashed SQL text as cache key, which may cause issues if parameters change type within - the same session, so we include parameter types to the key to prevent false positive cache hit. - """ - - sorted_parameters = sorted(parameters_types.items()) # dict keys are unique, so the sorting is stable - - yql_with_params = yql_text + "".join([k + str(v) for k, v in sorted_parameters]) - name = hashlib.sha256(yql_with_params.encode("utf-8")).hexdigest() - return ydb.DataQuery(yql_text, parameters_types, name=name) - - @_handle_ydb_errors - def _execute_scan_query( - self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None - ) -> Generator[ydb.convert.ResultSet, None, None]: - settings = self._get_request_settings() - prepared_query = query - if isinstance(query, str) and parameters: - prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query, settings) - - if isinstance(query, str): - scan_query = ydb.ScanQuery(query, None) - else: - scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types) - - return self._execute_scan_query_in_driver(scan_query, parameters, settings) - - @_handle_ydb_errors - def _execute_dml( - self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None - ) -> ydb.convert.ResultSets: - settings = self._get_request_settings() - prepared_query = query - if isinstance(query, str) and parameters: - if self.tx_context: - prepared_query = self._run_operation_in_session(self._prepare, query, settings) - else: - prepared_query = self._retry_operation_in_pool(self._prepare, query, settings) - - if self.tx_context: - return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters, settings) - - return self._retry_operation_in_pool( - self._execute_in_session, self.tx_mode, prepared_query, parameters, settings - ) - - @_handle_ydb_errors - def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: - settings = self._get_request_settings() - return self._retry_operation_in_pool(self._execute_scheme, query, settings) - - @staticmethod - def _execute_scheme( - session: ydb.Session, - query: str, - settings: ydb.BaseRequestSettings, - ) -> ydb.convert.ResultSets: - return session.execute_scheme(query, settings) - - @staticmethod - def _describe_table( - session: ydb.Session, abs_table_path: str, settings: ydb.BaseRequestSettings - ) -> ydb.TableDescription: - return session.describe_table(abs_table_path, settings) - - @staticmethod - def _describe_path(session: ydb.Session, table_path: str, settings: ydb.BaseRequestSettings) -> ydb.SchemeEntry: - return session._driver.scheme_client.describe_path(table_path, settings) - - @staticmethod - def _list_directory(session: ydb.Session, abs_dir_path: str, settings: ydb.BaseRequestSettings) -> ydb.Directory: - return session._driver.scheme_client.list_directory(abs_dir_path, settings) - - @staticmethod - def _prepare(session: ydb.Session, query: str, settings: ydb.BaseRequestSettings) -> ydb.DataQuery: - return session.prepare(query, settings) - - @staticmethod - def _execute_in_tx( - tx_context: ydb.TxContext, - prepared_query: ydb.DataQuery, - parameters: Optional[Mapping[str, Any]], - settings: ydb.BaseRequestSettings, - ) -> ydb.convert.ResultSets: - return tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings) - - @staticmethod - def _execute_in_session( - session: ydb.Session, - tx_mode: ydb.AbstractTransactionModeBuilder, - prepared_query: ydb.DataQuery, - parameters: Optional[Mapping[str, Any]], - settings: ydb.BaseRequestSettings, - ) -> ydb.convert.ResultSets: - return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings) - - def _execute_scan_query_in_driver( - self, - scan_query: ydb.ScanQuery, - parameters: Optional[Mapping[str, Any]], - settings: ydb.BaseRequestSettings, - ) -> Generator[ydb.convert.ResultSet, None, None]: - chunk: ydb.ScanQueryResult - for chunk in self.driver.table_client.scan_query(scan_query, parameters, settings): - yield chunk.result_set - - def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): - return callee(self.tx_context, *args, **kwargs) - - def _run_operation_in_session(self, callee: collections.abc.Callable, *args, **kwargs): - return callee(self.tx_context.session, *args, **kwargs) - - def _retry_operation_in_pool(self, callee: collections.abc.Callable, *args, **kwargs): - return self.session_pool.retry_operation_sync(callee, None, *args, **kwargs) - - def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): - try: - for chunk in chunks_iterable: - self.description = [ - ( - col.name, - get_column_type(col.type), - None, - None, - None, - None, - None, - ) - for col in chunk.columns - ] - for row in chunk.rows: - # returns tuple to be compatible with SqlAlchemy and because - # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany - yield row[::] - except ydb.Error as e: - raise DatabaseError(e.message, original_error=e) from e - - def _ensure_prefetched(self): - if self.rows is not None and self._rows_prefetched is None: - self._rows_prefetched = list(self.rows) - self.rows = iter(self._rows_prefetched) - return self._rows_prefetched - - def executemany(self, operation: YdbQuery, seq_of_parameters: Optional[Sequence[Mapping[str, Any]]]): - for parameters in seq_of_parameters: - self.execute(operation, parameters) - - def executescript(self, script): - return self.execute(script) - - def fetchone(self): - return next(self.rows or iter([]), None) - - def fetchmany(self, size=None): - return list(itertools.islice(self.rows, size or self.arraysize)) - - def fetchall(self): - return list(self.rows) - - def nextset(self): - self.fetchall() - - def setinputsizes(self, sizes): - pass - - def setoutputsize(self, column=None): - pass - - def close(self): - self.rows = None - self._rows_prefetched = None - - @property - def rowcount(self): - return len(self._ensure_prefetched()) - - def _get_request_settings(self) -> ydb.BaseRequestSettings: - settings = self.request_settings.make_copy() - - if self.request_settings.trace_id is None: - settings = settings.with_trace_id(maybe_get_current_trace_id()) - - return settings - - -class AsyncCursor(Cursor): - _await = staticmethod(util.await_only) - - @staticmethod - async def _describe_table( - session: ydb.aio.table.Session, - abs_table_path: str, - settings: ydb.BaseRequestSettings, - ) -> ydb.TableDescription: - return await session.describe_table(abs_table_path, settings) - - @staticmethod - async def _describe_path( - session: ydb.aio.table.Session, - abs_table_path: str, - settings: ydb.BaseRequestSettings, - ) -> ydb.SchemeEntry: - return await session._driver.scheme_client.describe_path(abs_table_path, settings) - - @staticmethod - async def _list_directory( - session: ydb.aio.table.Session, - abs_dir_path: str, - settings: ydb.BaseRequestSettings, - ) -> ydb.Directory: - return await session._driver.scheme_client.list_directory(abs_dir_path, settings) - - @staticmethod - async def _execute_scheme( - session: ydb.aio.table.Session, - query: str, - settings: ydb.BaseRequestSettings, - ) -> ydb.convert.ResultSets: - return await session.execute_scheme(query, settings) - - @staticmethod - async def _prepare( - session: ydb.aio.table.Session, - query: str, - settings: ydb.BaseRequestSettings, - ) -> ydb.DataQuery: - return await session.prepare(query, settings) - - @staticmethod - async def _execute_in_tx( - tx_context: ydb.aio.table.TxContext, - prepared_query: ydb.DataQuery, - parameters: Optional[Mapping[str, Any]], - settings: ydb.BaseRequestSettings, - ) -> ydb.convert.ResultSets: - return await tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings) - - @staticmethod - async def _execute_in_session( - session: ydb.aio.table.Session, - tx_mode: ydb.AbstractTransactionModeBuilder, - prepared_query: ydb.DataQuery, - parameters: Optional[Mapping[str, Any]], - settings: ydb.BaseRequestSettings, - ) -> ydb.convert.ResultSets: - return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings) - - def _execute_scan_query_in_driver( - self, - scan_query: ydb.ScanQuery, - parameters: Optional[Mapping[str, Any]], - settings: ydb.BaseRequestSettings, - ) -> Generator[ydb.convert.ResultSet, None, None]: - iterator: AsyncIterator[ydb.ScanQueryResult] = self._await( - self.driver.table_client.scan_query(scan_query, parameters, settings) - ) - while True: - try: - result = self._await(iterator.__anext__()) - yield result.result_set - except StopAsyncIteration: - break - - def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs): - return self._await(callee(self.tx_context, *args, **kwargs)) - - def _run_operation_in_session(self, callee: collections.abc.Coroutine, *args, **kwargs): - return self._await(callee(self.tx_context.session, *args, **kwargs)) - - def _retry_operation_in_pool(self, callee: collections.abc.Coroutine, *args, **kwargs): - return self._await(self.session_pool.retry_operation(callee, *args, **kwargs)) diff --git a/ydb_sqlalchemy/dbapi/errors.py b/ydb_sqlalchemy/dbapi/errors.py deleted file mode 100644 index 79faba8..0000000 --- a/ydb_sqlalchemy/dbapi/errors.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import List, Optional - -import ydb -from google.protobuf.message import Message - - -class Warning(Exception): - pass - - -class Error(Exception): - def __init__( - self, - message: str, - original_error: Optional[ydb.Error] = None, - ): - super(Error, self).__init__(message) - - self.original_error = original_error - if original_error: - pretty_issues = _pretty_issues(original_error.issues) - self.issues = original_error.issues - self.message = pretty_issues or message - self.status = original_error.status - - -class InterfaceError(Error): - pass - - -class DatabaseError(Error): - pass - - -class DataError(DatabaseError): - pass - - -class OperationalError(DatabaseError): - pass - - -class IntegrityError(DatabaseError): - pass - - -class InternalError(DatabaseError): - pass - - -class ProgrammingError(DatabaseError): - pass - - -class NotSupportedError(DatabaseError): - pass - - -def _pretty_issues(issues: List[Message]) -> str: - if issues is None: - return None - - children_messages = [_get_messages(issue, root=True) for issue in issues] - - if None in children_messages: - return None - - return "\n" + "\n".join(children_messages) - - -def _get_messages(issue: Message, max_depth: int = 100, indent: int = 2, depth: int = 0, root: bool = False) -> str: - if depth >= max_depth: - return None - - margin_str = " " * depth * indent - pre_message = "" - children = "" - - if issue.issues: - collapsed_messages = [] - while not root and len(issue.issues) == 1: - collapsed_messages.append(issue.message) - issue = issue.issues[0] - - if collapsed_messages: - pre_message = f"{margin_str}{', '.join(collapsed_messages)}\n" - depth += 1 - margin_str = " " * depth * indent - - children_messages = [ - _get_messages(iss, max_depth=max_depth, indent=indent, depth=depth + 1) for iss in issue.issues - ] - - if None in children_messages: - return None - - children = "\n".join(children_messages) - - return ( - f"{pre_message}{margin_str}{issue.message}\n{margin_str}" - f"severity level: {issue.severity}\n{margin_str}" - f"issue code: {issue.issue_code}\n{children}" - ) diff --git a/ydb_sqlalchemy/driver/wrapper.py b/ydb_sqlalchemy/driver/wrapper.py new file mode 100644 index 0000000..d777740 --- /dev/null +++ b/ydb_sqlalchemy/driver/wrapper.py @@ -0,0 +1,106 @@ +from sqlalchemy.engine.interfaces import AdaptedConnection + +from sqlalchemy.util.concurrency import await_only +from ydb_dbapi import AsyncConnection, AsyncCursor +import ydb + + +class AdaptedAsyncConnection(AdaptedConnection): + def __init__(self, connection: AsyncConnection): + self._connection: AsyncConnection = connection + + @property + def _driver(self): + return self._connection._driver + + @property + def _session_pool(self): + return self._connection._session_pool + + @property + def _tx_context(self): + return self._connection._tx_context + + @property + def _tx_mode(self): + return self._connection._tx_mode + + @property + def interactive_transaction(self): + return self._connection.interactive_transaction + + def cursor(self): + return AdaptedAsyncCursor(self._connection.cursor()) + + def begin(self): + return await_only(self._connection.begin()) + + def commit(self): + return await_only(self._connection.commit()) + + def rollback(self): + return await_only(self._connection.rollback()) + + def close(self): + return await_only(self._connection.close()) + + def set_isolation_level(self, level): + return self._connection.set_isolation_level(level) + + def get_isolation_level(self): + return self._connection.get_isolation_level() + + def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None: + self._connection.set_ydb_request_settings(value) + + def get_ydb_request_settings(self) -> ydb.BaseRequestSettings: + return self._connection.get_ydb_request_settings() + + def describe(self, table_path: str): + return await_only(self._connection.describe(table_path)) + + def check_exists(self, table_path: str): + return await_only(self._connection.check_exists(table_path)) + + def get_table_names(self): + return await_only(self._connection.get_table_names()) + + +class AdaptedAsyncCursor: + def __init__(self, cursor: AsyncCursor): + self._cursor = cursor + + @property + def description(self): + return self._cursor.description + + @property + def rowcount(self): + return self._cursor.rowcount + + def fetchone(self): + return await_only(self._cursor.fetchone()) + + def fetchmany(self, size=None): + return await_only(self._cursor.fetchmany(size=size)) + + def fetchall(self): + return await_only(self._cursor.fetchall()) + + def execute_scheme(self, sql, parameters=None): + return await_only(self._cursor.execute_scheme(sql, parameters)) + + def execute(self, sql, parameters=None): + return await_only(self._cursor.execute(sql, parameters)) + + def executemany(self, sql, parameters=None): + return await_only(self._cursor.executemany(sql, parameters)) + + def close(self): + return await_only(self._cursor.close()) + + def setinputsizes(self, *args): + pass + + def setoutputsizes(self, *args): + pass diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 8f0d76f..135eada 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -24,8 +24,8 @@ from sqlalchemy.sql.elements import ClauseList from sqlalchemy.util.compat import inspect_getfullargspec -import ydb_sqlalchemy.dbapi as dbapi -from ydb_sqlalchemy.dbapi.constants import YDB_KEYWORDS +import ydb_dbapi as dbapi +from ydb_sqlalchemy.driver.wrapper import AdaptedAsyncConnection from ydb_sqlalchemy.sqlalchemy.dml import Upsert from . import types @@ -54,7 +54,7 @@ class YqlIdentifierPreparer(IdentifierPreparer): reserved_words = IdentifierPreparer.reserved_words - reserved_words.update(YDB_KEYWORDS) + reserved_words.update(dbapi.YDB_KEYWORDS) def __init__(self, dialect): super(YqlIdentifierPreparer, self).__init__( @@ -268,7 +268,7 @@ def limit_clause(self, select, **kw): select._offset_clause, types.UInt64, skip_types=(types.UInt64, types.UInt32, types.UInt16, types.UInt8) ) if select._limit_clause is None: - text += "\n LIMIT NULL" + text += "\n LIMIT 1000" # For some reason, YDB do not support LIMIT NULL OFFSET text += " OFFSET " + self.process(offset_clause, **kw) return text @@ -578,17 +578,6 @@ def _get_column_info(t): return COLUMN_TYPES[t], nullable -class YdbScanQueryCharacteristic(characteristics.ConnectionCharacteristic): - def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> None: - dialect.reset_ydb_scan_query(dbapi_connection) - - def set_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection, value: bool) -> None: - dialect.set_ydb_scan_query(dbapi_connection, value) - - def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> bool: - return dialect.get_ydb_scan_query(dbapi_connection) - - class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic): def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> None: dialect.reset_ydb_request_settings(dbapi_connection) @@ -650,7 +639,6 @@ class YqlDialect(StrCompileDialect): connection_characteristics = util.immutabledict( { "isolation_level": characteristics.IsolationLevelCharacteristic(), - "ydb_scan_query": YdbScanQueryCharacteristic(), "ydb_request_settings": YdbRequestSettingsCharacteristic(), } ) @@ -679,7 +667,7 @@ class YqlDialect(StrCompileDialect): @classmethod def import_dbapi(cls: Any): - return dbapi.YdbDBApi() + return dbapi def __init__( self, @@ -778,15 +766,6 @@ def get_default_isolation_level(self, dbapi_conn: dbapi.Connection) -> str: def get_isolation_level(self, dbapi_connection: dbapi.Connection) -> str: return dbapi_connection.get_isolation_level() - def set_ydb_scan_query(self, dbapi_connection: dbapi.Connection, value: bool) -> None: - dbapi_connection.set_ydb_scan_query(value) - - def reset_ydb_scan_query(self, dbapi_connection: dbapi.Connection): - self.set_ydb_scan_query(dbapi_connection, False) - - def get_ydb_scan_query(self, dbapi_connection: dbapi.Connection) -> bool: - return dbapi_connection.get_ydb_scan_query() - def set_ydb_request_settings( self, dbapi_connection: dbapi.Connection, @@ -853,25 +832,43 @@ def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types): ) return f"{declarations}\n{statement}" + def __merge_parameters_values_and_types( + self, values: Mapping[str, Any], types: Mapping[str, Any], execute_many: bool + ) -> Sequence[Mapping[str, ydb.TypedValue]]: + if isinstance(values, collections.abc.Mapping): + values = [values] + + result_list = [] + for value_map in values: + result = {} + for key in value_map.keys(): + if key in types: + result[key] = ydb.TypedValue(value_map[key], types[key]) + else: + result[key] = values[key] + result_list.append(result) + return result_list if execute_many else result_list[0] + def _make_ydb_operation( self, statement: str, context: Optional[DefaultExecutionContext] = None, parameters: Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]] = None, execute_many: bool = False, - ) -> Tuple[dbapi.YdbQuery, Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]: + ) -> Tuple[Optional[Union[Sequence[Mapping[str, Any]], Mapping[str, Any]]]]: is_ddl = context.isddl if context is not None else False if not is_ddl and parameters: parameters_types = context.compiled.get_bind_types(parameters) - parameters_types = {f"${k}": v for k, v in parameters_types.items()} + if parameters_types != {}: + parameters = self.__merge_parameters_values_and_types(parameters, parameters_types, execute_many) statement, parameters = self._format_variables(statement, parameters, execute_many) if self._add_declare_for_yql_stmt_vars: statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types) - return dbapi.YdbQuery(yql_text=statement, parameters_types=parameters_types, is_ddl=is_ddl), parameters + return statement, parameters statement, parameters = self._format_variables(statement, parameters, execute_many) - return dbapi.YdbQuery(yql_text=statement, is_ddl=is_ddl), parameters + return statement, parameters def do_ping(self, dbapi_connection: dbapi.Connection) -> bool: cursor = dbapi_connection.cursor() @@ -900,7 +897,11 @@ def do_execute( context: Optional[DefaultExecutionContext] = None, ) -> None: operation, parameters = self._make_ydb_operation(statement, context, parameters, execute_many=False) - cursor.execute(operation, parameters) + is_ddl = context.isddl if context is not None else False + if is_ddl: + cursor.execute_scheme(operation, parameters) + else: + cursor.execute(operation, parameters) class AsyncYqlDialect(YqlDialect): @@ -909,4 +910,4 @@ class AsyncYqlDialect(YqlDialect): supports_statement_cache = True def connect(self, *cargs, **cparams): - return self.loaded_dbapi.async_connect(*cargs, **cparams) + return AdaptedAsyncConnection(util.await_only(self.loaded_dbapi.async_connect(*cargs, **cparams)))