diff --git a/ydb_sqlalchemy/dbapi/connection.py b/ydb_sqlalchemy/dbapi/connection.py index fa0941e..dc4ac97 100644 --- a/ydb_sqlalchemy/dbapi/connection.py +++ b/ydb_sqlalchemy/dbapi/connection.py @@ -58,10 +58,17 @@ def __init__( self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite() self.tx_context: Optional[ydb.TxContext] = None self.use_scan_query: bool = False + self.request_settings: ydb.BaseRequestSettings = ydb.BaseRequestSettings() def cursor(self): return self._cursor_class( - self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix + driver=self.driver, + session_pool=self.session_pool, + tx_mode=self.tx_mode, + tx_context=self.tx_context, + request_settings=self.request_settings, + use_scan_query=self.use_scan_query, + table_path_prefix=self.table_path_prefix, ) def describe(self, table_path: str) -> ydb.TableDescription: @@ -69,7 +76,7 @@ def describe(self, table_path: str) -> ydb.TableDescription: cursor = self.cursor() return cursor.describe_table(abs_table_path) - def check_exists(self, table_path: str) -> ydb.SchemeEntry: + def check_exists(self, table_path: str) -> bool: abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path) cursor = self.cursor() return cursor.check_exists(abs_table_path) @@ -124,6 +131,12 @@ def set_ydb_scan_query(self, value: bool) -> None: def get_ydb_scan_query(self) -> bool: return self.use_scan_query + def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None: + self.request_settings = value + + def get_ydb_request_settings(self) -> ydb.BaseRequestSettings: + return self.request_settings + def begin(self): self.tx_context = None if self.interactive_transaction and not self.use_scan_query: diff --git a/ydb_sqlalchemy/dbapi/cursor.py b/ydb_sqlalchemy/dbapi/cursor.py index 22dfdaa..c104356 100644 --- a/ydb_sqlalchemy/dbapi/cursor.py +++ b/ydb_sqlalchemy/dbapi/cursor.py @@ -5,16 +5,7 @@ import itertools import posixpath from collections.abc import AsyncIterator -from typing import ( - Any, - Dict, - Generator, - List, - Mapping, - Optional, - Sequence, - Union, -) +from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Union import ydb import ydb.aio @@ -29,6 +20,7 @@ OperationalError, ProgrammingError, ) +from .tracing import maybe_get_current_trace_id def get_column_type(type_obj: Any) -> str: @@ -87,6 +79,7 @@ def __init__( driver: Union[ydb.Driver, ydb.aio.Driver], session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], tx_mode: ydb.AbstractTransactionModeBuilder, + request_settings: ydb.BaseRequestSettings, tx_context: Optional[ydb.BaseTxContext] = None, use_scan_query: bool = False, table_path_prefix: str = "", @@ -94,28 +87,32 @@ def __init__( self.driver = driver self.session_pool = session_pool self.tx_mode = tx_mode + self.request_settings = request_settings self.tx_context = tx_context self.use_scan_query = use_scan_query + self.root_directory = table_path_prefix self.description = None self.arraysize = 1 self.rows = None self._rows_prefetched = None - self.root_directory = table_path_prefix @_handle_ydb_errors def describe_table(self, abs_table_path: str) -> ydb.TableDescription: - return self._retry_operation_in_pool(self._describe_table, abs_table_path) + settings = self._get_request_settings() + return self._retry_operation_in_pool(self._describe_table, abs_table_path, settings) def check_exists(self, abs_table_path: str) -> bool: + settings = self._get_request_settings() try: - self._retry_operation_in_pool(self._describe_path, abs_table_path) + self._retry_operation_in_pool(self._describe_path, abs_table_path, settings) return True except ydb.SchemeError: return False @_handle_ydb_errors def get_table_names(self, abs_dir_path: str) -> List[str]: - directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path) + settings = self._get_request_settings() + directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path, settings) result = [] for child in directory.children: child_abs_path = posixpath.join(abs_dir_path, child.name) @@ -180,62 +177,76 @@ def _make_data_query( def _execute_scan_query( self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None ) -> Generator[ydb.convert.ResultSet, None, None]: + settings = self._get_request_settings() prepared_query = query if isinstance(query, str) and parameters: - prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query) + prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query, settings) if isinstance(query, str): scan_query = ydb.ScanQuery(query, None) else: scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types) - return self._execute_scan_query_in_driver(scan_query, parameters) + return self._execute_scan_query_in_driver(scan_query, parameters, settings) @_handle_ydb_errors def _execute_dml( self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None ) -> ydb.convert.ResultSets: + settings = self._get_request_settings() prepared_query = query if isinstance(query, str) and parameters: if self.tx_context: - prepared_query = self._run_operation_in_session(self._prepare, query) + prepared_query = self._run_operation_in_session(self._prepare, query, settings) else: - prepared_query = self._retry_operation_in_pool(self._prepare, query) + prepared_query = self._retry_operation_in_pool(self._prepare, query, settings) if self.tx_context: - return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters) + return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters, settings) - return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters) + return self._retry_operation_in_pool( + self._execute_in_session, self.tx_mode, prepared_query, parameters, settings + ) @_handle_ydb_errors def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: - return self._retry_operation_in_pool(self._execute_scheme, query) + settings = self._get_request_settings() + return self._retry_operation_in_pool(self._execute_scheme, query, settings) @staticmethod - def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets: - return session.execute_scheme(query) + def _execute_scheme( + session: ydb.Session, + query: str, + settings: ydb.BaseRequestSettings, + ) -> ydb.convert.ResultSets: + return session.execute_scheme(query, settings) @staticmethod - def _describe_table(session: ydb.Session, abs_table_path: str) -> ydb.TableDescription: - return session.describe_table(abs_table_path) + def _describe_table( + session: ydb.Session, abs_table_path: str, settings: ydb.BaseRequestSettings + ) -> ydb.TableDescription: + return session.describe_table(abs_table_path, settings) @staticmethod - def _describe_path(session: ydb.Session, table_path: str) -> ydb.SchemeEntry: - return session._driver.scheme_client.describe_path(table_path) + def _describe_path(session: ydb.Session, table_path: str, settings: ydb.BaseRequestSettings) -> ydb.SchemeEntry: + return session._driver.scheme_client.describe_path(table_path, settings) @staticmethod - def _list_directory(session: ydb.Session, abs_dir_path: str) -> ydb.Directory: - return session._driver.scheme_client.list_directory(abs_dir_path) + def _list_directory(session: ydb.Session, abs_dir_path: str, settings: ydb.BaseRequestSettings) -> ydb.Directory: + return session._driver.scheme_client.list_directory(abs_dir_path, settings) @staticmethod - def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery: - return session.prepare(query) + def _prepare(session: ydb.Session, query: str, settings: ydb.BaseRequestSettings) -> ydb.DataQuery: + return session.prepare(query, settings) @staticmethod def _execute_in_tx( - tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + tx_context: ydb.TxContext, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], + settings: ydb.BaseRequestSettings, ) -> ydb.convert.ResultSets: - return tx_context.execute(prepared_query, parameters, commit_tx=False) + return tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings) @staticmethod def _execute_in_session( @@ -243,16 +254,18 @@ def _execute_in_session( tx_mode: ydb.AbstractTransactionModeBuilder, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]], + settings: ydb.BaseRequestSettings, ) -> ydb.convert.ResultSets: - return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) + return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings) def _execute_scan_query_in_driver( self, scan_query: ydb.ScanQuery, parameters: Optional[Mapping[str, Any]], + settings: ydb.BaseRequestSettings, ) -> Generator[ydb.convert.ResultSet, None, None]: chunk: ydb.ScanQueryResult - for chunk in self.driver.table_client.scan_query(scan_query, parameters): + for chunk in self.driver.table_client.scan_query(scan_query, parameters, settings): yield chunk.result_set def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): @@ -325,35 +338,66 @@ def close(self): def rowcount(self): return len(self._ensure_prefetched()) + def _get_request_settings(self) -> ydb.BaseRequestSettings: + settings = self.request_settings.make_copy() + + if self.request_settings.trace_id is None: + settings = settings.with_trace_id(maybe_get_current_trace_id()) + + return settings + class AsyncCursor(Cursor): _await = staticmethod(util.await_only) @staticmethod - async def _describe_table(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.TableDescription: - return await session.describe_table(abs_table_path) + async def _describe_table( + session: ydb.aio.table.Session, + abs_table_path: str, + settings: ydb.BaseRequestSettings, + ) -> ydb.TableDescription: + return await session.describe_table(abs_table_path, settings) @staticmethod - async def _describe_path(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.SchemeEntry: - return await session._driver.scheme_client.describe_path(abs_table_path) + async def _describe_path( + session: ydb.aio.table.Session, + abs_table_path: str, + settings: ydb.BaseRequestSettings, + ) -> ydb.SchemeEntry: + return await session._driver.scheme_client.describe_path(abs_table_path, settings) @staticmethod - async def _list_directory(session: ydb.aio.table.Session, abs_dir_path: str) -> ydb.Directory: - return await session._driver.scheme_client.list_directory(abs_dir_path) + async def _list_directory( + session: ydb.aio.table.Session, + abs_dir_path: str, + settings: ydb.BaseRequestSettings, + ) -> ydb.Directory: + return await session._driver.scheme_client.list_directory(abs_dir_path, settings) @staticmethod - async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets: - return await session.execute_scheme(query) + async def _execute_scheme( + session: ydb.aio.table.Session, + query: str, + settings: ydb.BaseRequestSettings, + ) -> ydb.convert.ResultSets: + return await session.execute_scheme(query, settings) @staticmethod - async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery: - return await session.prepare(query) + async def _prepare( + session: ydb.aio.table.Session, + query: str, + settings: ydb.BaseRequestSettings, + ) -> ydb.DataQuery: + return await session.prepare(query, settings) @staticmethod async def _execute_in_tx( - tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] + tx_context: ydb.aio.table.TxContext, + prepared_query: ydb.DataQuery, + parameters: Optional[Mapping[str, Any]], + settings: ydb.BaseRequestSettings, ) -> ydb.convert.ResultSets: - return await tx_context.execute(prepared_query, parameters, commit_tx=False) + return await tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings) @staticmethod async def _execute_in_session( @@ -361,16 +405,18 @@ async def _execute_in_session( tx_mode: ydb.AbstractTransactionModeBuilder, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]], + settings: ydb.BaseRequestSettings, ) -> ydb.convert.ResultSets: - return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) + return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings) def _execute_scan_query_in_driver( self, scan_query: ydb.ScanQuery, parameters: Optional[Mapping[str, Any]], + settings: ydb.BaseRequestSettings, ) -> Generator[ydb.convert.ResultSet, None, None]: iterator: AsyncIterator[ydb.ScanQueryResult] = self._await( - self.driver.table_client.scan_query(scan_query, parameters) + self.driver.table_client.scan_query(scan_query, parameters, settings) ) while True: try: diff --git a/ydb_sqlalchemy/dbapi/tracing.py b/ydb_sqlalchemy/dbapi/tracing.py new file mode 100644 index 0000000..86f4ee8 --- /dev/null +++ b/ydb_sqlalchemy/dbapi/tracing.py @@ -0,0 +1,16 @@ +import importlib.util +from typing import Optional + + +def maybe_get_current_trace_id() -> Optional[str]: + # Check if OpenTelemetry is available + if importlib.util.find_spec("opentelemetry"): + from opentelemetry import trace + + current_span = trace.get_current_span() + + if current_span.get_span_context().is_valid: + return format(current_span.get_span_context().trace_id, "032x") + + # Return None if OpenTelemetry is not available or trace ID is invalid + return None diff --git a/ydb_sqlalchemy/sqlalchemy/__init__.py b/ydb_sqlalchemy/sqlalchemy/__init__.py index 0890e0d..8f0d76f 100644 --- a/ydb_sqlalchemy/sqlalchemy/__init__.py +++ b/ydb_sqlalchemy/sqlalchemy/__init__.py @@ -585,8 +585,21 @@ def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Co def set_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection, value: bool) -> None: dialect.set_ydb_scan_query(dbapi_connection, value) - def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> Any: - dialect.get_ydb_scan_query(dbapi_connection) + def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> bool: + return dialect.get_ydb_scan_query(dbapi_connection) + + +class YdbRequestSettingsCharacteristic(characteristics.ConnectionCharacteristic): + def reset_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> None: + dialect.reset_ydb_request_settings(dbapi_connection) + + def set_characteristic( + self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection, value: ydb.BaseRequestSettings + ) -> None: + dialect.set_ydb_request_settings(dbapi_connection, value) + + def get_characteristic(self, dialect: "YqlDialect", dbapi_connection: dbapi.Connection) -> ydb.BaseRequestSettings: + return dialect.get_ydb_request_settings(dbapi_connection) class YqlDialect(StrCompileDialect): @@ -638,6 +651,7 @@ class YqlDialect(StrCompileDialect): { "isolation_level": characteristics.IsolationLevelCharacteristic(), "ydb_scan_query": YdbScanQueryCharacteristic(), + "ydb_request_settings": YdbRequestSettingsCharacteristic(), } ) @@ -770,9 +784,22 @@ def set_ydb_scan_query(self, dbapi_connection: dbapi.Connection, value: bool) -> def reset_ydb_scan_query(self, dbapi_connection: dbapi.Connection): self.set_ydb_scan_query(dbapi_connection, False) - def get_ydb_scan_query(self, dbapi_connection: dbapi.Connection) -> str: + def get_ydb_scan_query(self, dbapi_connection: dbapi.Connection) -> bool: return dbapi_connection.get_ydb_scan_query() + def set_ydb_request_settings( + self, + dbapi_connection: dbapi.Connection, + value: ydb.BaseRequestSettings, + ) -> None: + dbapi_connection.set_ydb_request_settings(value) + + def reset_ydb_request_settings(self, dbapi_connection: dbapi.Connection): + self.set_ydb_request_settings(dbapi_connection, ydb.BaseRequestSettings()) + + def get_ydb_request_settings(self, dbapi_connection: dbapi.Connection) -> ydb.BaseRequestSettings: + return dbapi_connection.get_ydb_request_settings() + def connect(self, *cargs, **cparams): return self.loaded_dbapi.connect(*cargs, **cparams) diff --git a/ydb_sqlalchemy/sqlalchemy/types.py b/ydb_sqlalchemy/sqlalchemy/types.py index c97a3e0..557ce3d 100644 --- a/ydb_sqlalchemy/sqlalchemy/types.py +++ b/ydb_sqlalchemy/sqlalchemy/types.py @@ -3,7 +3,7 @@ from sqlalchemy import ARRAY, ColumnElement, exc, types from sqlalchemy.sql import type_api -from .datetime_types import YqlTimestamp, YqlDateTime # noqa: F401 +from .datetime_types import YqlDateTime, YqlTimestamp # noqa: F401 from .json import YqlJSON # noqa: F401