Skip to content

Commit

Permalink
Add request settings: request_timeout and trace_id
Browse files Browse the repository at this point in the history
  • Loading branch information
Roman Tretiak committed Oct 31, 2024
1 parent 3ff29ac commit 06f35f9
Show file tree
Hide file tree
Showing 5 changed files with 157 additions and 55 deletions.
17 changes: 15 additions & 2 deletions ydb_sqlalchemy/dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,25 @@ def __init__(
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
self.tx_context: Optional[ydb.TxContext] = None
self.use_scan_query: bool = False
self.request_settings: ydb.BaseRequestSettings = ydb.BaseRequestSettings()

def cursor(self):
return self._cursor_class(
self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix
driver=self.driver,
session_pool=self.session_pool,
tx_mode=self.tx_mode,
tx_context=self.tx_context,
request_settings=self.request_settings,
use_scan_query=self.use_scan_query,
table_path_prefix=self.table_path_prefix,
)

def describe(self, table_path: str) -> ydb.TableDescription:
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
cursor = self.cursor()
return cursor.describe_table(abs_table_path)

def check_exists(self, table_path: str) -> ydb.SchemeEntry:
def check_exists(self, table_path: str) -> bool:
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
cursor = self.cursor()
return cursor.check_exists(abs_table_path)
Expand Down Expand Up @@ -124,6 +131,12 @@ def set_ydb_scan_query(self, value: bool) -> None:
def get_ydb_scan_query(self) -> bool:
return self.use_scan_query

def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None:
self.request_settings = value

def get_ydb_request_settings(self) -> ydb.BaseRequestSettings:
return self.request_settings

def begin(self):
self.tx_context = None
if self.interactive_transaction and not self.use_scan_query:
Expand Down
144 changes: 95 additions & 49 deletions ydb_sqlalchemy/dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import itertools
import posixpath
from collections.abc import AsyncIterator
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Sequence,
Union,
)
from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Union

import ydb
import ydb.aio
Expand All @@ -29,6 +20,7 @@
OperationalError,
ProgrammingError,
)
from .tracing import maybe_get_current_trace_id


def get_column_type(type_obj: Any) -> str:
Expand Down Expand Up @@ -87,35 +79,40 @@ def __init__(
driver: Union[ydb.Driver, ydb.aio.Driver],
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
tx_mode: ydb.AbstractTransactionModeBuilder,
request_settings: ydb.BaseRequestSettings,
tx_context: Optional[ydb.BaseTxContext] = None,
use_scan_query: bool = False,
table_path_prefix: str = "",
):
self.driver = driver
self.session_pool = session_pool
self.tx_mode = tx_mode
self.request_settings = request_settings
self.tx_context = tx_context
self.use_scan_query = use_scan_query
self.root_directory = table_path_prefix
self.description = None
self.arraysize = 1
self.rows = None
self._rows_prefetched = None
self.root_directory = table_path_prefix

@_handle_ydb_errors
def describe_table(self, abs_table_path: str) -> ydb.TableDescription:
return self._retry_operation_in_pool(self._describe_table, abs_table_path)
settings = self._get_request_settings()
return self._retry_operation_in_pool(self._describe_table, abs_table_path, settings)

def check_exists(self, abs_table_path: str) -> bool:
settings = self._get_request_settings()
try:
self._retry_operation_in_pool(self._describe_path, abs_table_path)
self._retry_operation_in_pool(self._describe_path, abs_table_path, settings)
return True
except ydb.SchemeError:
return False

@_handle_ydb_errors
def get_table_names(self, abs_dir_path: str) -> List[str]:
directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path)
settings = self._get_request_settings()
directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path, settings)
result = []
for child in directory.children:
child_abs_path = posixpath.join(abs_dir_path, child.name)
Expand Down Expand Up @@ -180,79 +177,95 @@ def _make_data_query(
def _execute_scan_query(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
) -> Generator[ydb.convert.ResultSet, None, None]:
settings = self._get_request_settings()
prepared_query = query
if isinstance(query, str) and parameters:
prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query)
prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query, settings)

if isinstance(query, str):
scan_query = ydb.ScanQuery(query, None)
else:
scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types)

return self._execute_scan_query_in_driver(scan_query, parameters)
return self._execute_scan_query_in_driver(scan_query, parameters, settings)

@_handle_ydb_errors
def _execute_dml(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
) -> ydb.convert.ResultSets:
settings = self._get_request_settings()
prepared_query = query
if isinstance(query, str) and parameters:
if self.tx_context:
prepared_query = self._run_operation_in_session(self._prepare, query)
prepared_query = self._run_operation_in_session(self._prepare, query, settings)
else:
prepared_query = self._retry_operation_in_pool(self._prepare, query)
prepared_query = self._retry_operation_in_pool(self._prepare, query, settings)

if self.tx_context:
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters)
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters, settings)

return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters)
return self._retry_operation_in_pool(
self._execute_in_session, self.tx_mode, prepared_query, parameters, settings
)

@_handle_ydb_errors
def _execute_ddl(self, query: str) -> ydb.convert.ResultSets:
return self._retry_operation_in_pool(self._execute_scheme, query)
settings = self._get_request_settings()
return self._retry_operation_in_pool(self._execute_scheme, query, settings)

@staticmethod
def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets:
return session.execute_scheme(query)
def _execute_scheme(
session: ydb.Session,
query: str,
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return session.execute_scheme(query, settings)

@staticmethod
def _describe_table(session: ydb.Session, abs_table_path: str) -> ydb.TableDescription:
return session.describe_table(abs_table_path)
def _describe_table(
session: ydb.Session, abs_table_path: str, settings: ydb.BaseRequestSettings
) -> ydb.TableDescription:
return session.describe_table(abs_table_path, settings)

@staticmethod
def _describe_path(session: ydb.Session, table_path: str) -> ydb.SchemeEntry:
return session._driver.scheme_client.describe_path(table_path)
def _describe_path(session: ydb.Session, table_path: str, settings: ydb.BaseRequestSettings) -> ydb.SchemeEntry:
return session._driver.scheme_client.describe_path(table_path, settings)

@staticmethod
def _list_directory(session: ydb.Session, abs_dir_path: str) -> ydb.Directory:
return session._driver.scheme_client.list_directory(abs_dir_path)
def _list_directory(session: ydb.Session, abs_dir_path: str, settings: ydb.BaseRequestSettings) -> ydb.Directory:
return session._driver.scheme_client.list_directory(abs_dir_path, settings)

@staticmethod
def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery:
return session.prepare(query)
def _prepare(session: ydb.Session, query: str, settings: ydb.BaseRequestSettings) -> ydb.DataQuery:
return session.prepare(query, settings)

@staticmethod
def _execute_in_tx(
tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
tx_context: ydb.TxContext,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return tx_context.execute(prepared_query, parameters, commit_tx=False)
return tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings)

@staticmethod
def _execute_in_session(
session: ydb.Session,
tx_mode: ydb.AbstractTransactionModeBuilder,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings)

def _execute_scan_query_in_driver(
self,
scan_query: ydb.ScanQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> Generator[ydb.convert.ResultSet, None, None]:
chunk: ydb.ScanQueryResult
for chunk in self.driver.table_client.scan_query(scan_query, parameters):
for chunk in self.driver.table_client.scan_query(scan_query, parameters, settings):
yield chunk.result_set

def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
Expand Down Expand Up @@ -325,52 +338,85 @@ def close(self):
def rowcount(self):
return len(self._ensure_prefetched())

def _get_request_settings(self) -> ydb.BaseRequestSettings:
settings = self.request_settings.make_copy()

if self.request_settings.trace_id is None:
settings = settings.with_trace_id(maybe_get_current_trace_id())

return settings


class AsyncCursor(Cursor):
_await = staticmethod(util.await_only)

@staticmethod
async def _describe_table(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.TableDescription:
return await session.describe_table(abs_table_path)
async def _describe_table(
session: ydb.aio.table.Session,
abs_table_path: str,
settings: ydb.BaseRequestSettings,
) -> ydb.TableDescription:
return await session.describe_table(abs_table_path, settings)

@staticmethod
async def _describe_path(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.SchemeEntry:
return await session._driver.scheme_client.describe_path(abs_table_path)
async def _describe_path(
session: ydb.aio.table.Session,
abs_table_path: str,
settings: ydb.BaseRequestSettings,
) -> ydb.SchemeEntry:
return await session._driver.scheme_client.describe_path(abs_table_path, settings)

@staticmethod
async def _list_directory(session: ydb.aio.table.Session, abs_dir_path: str) -> ydb.Directory:
return await session._driver.scheme_client.list_directory(abs_dir_path)
async def _list_directory(
session: ydb.aio.table.Session,
abs_dir_path: str,
settings: ydb.BaseRequestSettings,
) -> ydb.Directory:
return await session._driver.scheme_client.list_directory(abs_dir_path, settings)

@staticmethod
async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets:
return await session.execute_scheme(query)
async def _execute_scheme(
session: ydb.aio.table.Session,
query: str,
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return await session.execute_scheme(query, settings)

@staticmethod
async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery:
return await session.prepare(query)
async def _prepare(
session: ydb.aio.table.Session,
query: str,
settings: ydb.BaseRequestSettings,
) -> ydb.DataQuery:
return await session.prepare(query, settings)

@staticmethod
async def _execute_in_tx(
tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
tx_context: ydb.aio.table.TxContext,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return await tx_context.execute(prepared_query, parameters, commit_tx=False)
return await tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings)

@staticmethod
async def _execute_in_session(
session: ydb.aio.table.Session,
tx_mode: ydb.AbstractTransactionModeBuilder,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings)

def _execute_scan_query_in_driver(
self,
scan_query: ydb.ScanQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> Generator[ydb.convert.ResultSet, None, None]:
iterator: AsyncIterator[ydb.ScanQueryResult] = self._await(
self.driver.table_client.scan_query(scan_query, parameters)
self.driver.table_client.scan_query(scan_query, parameters, settings)
)
while True:
try:
Expand Down
16 changes: 16 additions & 0 deletions ydb_sqlalchemy/dbapi/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import importlib.util
from typing import Optional


def maybe_get_current_trace_id() -> Optional[str]:
# Check if OpenTelemetry is available
if importlib.util.find_spec("opentelemetry"):
from opentelemetry import trace

current_span = trace.get_current_span()

if current_span.get_span_context().is_valid:
return format(current_span.get_span_context().trace_id, "032x")

# Return None if OpenTelemetry is not available or trace ID is invalid
return None
Loading

0 comments on commit 06f35f9

Please sign in to comment.