Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add request settings to execution options #59

Merged
merged 1 commit into from
Oct 31, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions ydb_sqlalchemy/dbapi/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,18 +58,25 @@ def __init__(
self.tx_mode: ydb.AbstractTransactionModeBuilder = ydb.SerializableReadWrite()
self.tx_context: Optional[ydb.TxContext] = None
self.use_scan_query: bool = False
self.request_settings: ydb.BaseRequestSettings = ydb.BaseRequestSettings()

def cursor(self):
return self._cursor_class(
self.driver, self.session_pool, self.tx_mode, self.tx_context, self.use_scan_query, self.table_path_prefix
driver=self.driver,
session_pool=self.session_pool,
tx_mode=self.tx_mode,
tx_context=self.tx_context,
request_settings=self.request_settings,
use_scan_query=self.use_scan_query,
table_path_prefix=self.table_path_prefix,
)

def describe(self, table_path: str) -> ydb.TableDescription:
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
cursor = self.cursor()
return cursor.describe_table(abs_table_path)

def check_exists(self, table_path: str) -> ydb.SchemeEntry:
def check_exists(self, table_path: str) -> bool:
abs_table_path = posixpath.join(self.database, self.table_path_prefix, table_path)
cursor = self.cursor()
return cursor.check_exists(abs_table_path)
Expand Down Expand Up @@ -124,6 +131,12 @@ def set_ydb_scan_query(self, value: bool) -> None:
def get_ydb_scan_query(self) -> bool:
return self.use_scan_query

def set_ydb_request_settings(self, value: ydb.BaseRequestSettings) -> None:
self.request_settings = value

def get_ydb_request_settings(self) -> ydb.BaseRequestSettings:
return self.request_settings

def begin(self):
self.tx_context = None
if self.interactive_transaction and not self.use_scan_query:
Expand Down
144 changes: 95 additions & 49 deletions ydb_sqlalchemy/dbapi/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,7 @@
import itertools
import posixpath
from collections.abc import AsyncIterator
from typing import (
Any,
Dict,
Generator,
List,
Mapping,
Optional,
Sequence,
Union,
)
from typing import Any, Dict, Generator, List, Mapping, Optional, Sequence, Union

import ydb
import ydb.aio
Expand All @@ -29,6 +20,7 @@
OperationalError,
ProgrammingError,
)
from .tracing import maybe_get_current_trace_id


def get_column_type(type_obj: Any) -> str:
Expand Down Expand Up @@ -87,35 +79,40 @@ def __init__(
driver: Union[ydb.Driver, ydb.aio.Driver],
session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool],
tx_mode: ydb.AbstractTransactionModeBuilder,
request_settings: ydb.BaseRequestSettings,
tx_context: Optional[ydb.BaseTxContext] = None,
use_scan_query: bool = False,
table_path_prefix: str = "",
):
self.driver = driver
self.session_pool = session_pool
self.tx_mode = tx_mode
self.request_settings = request_settings
self.tx_context = tx_context
self.use_scan_query = use_scan_query
self.root_directory = table_path_prefix
self.description = None
self.arraysize = 1
self.rows = None
self._rows_prefetched = None
self.root_directory = table_path_prefix

@_handle_ydb_errors
def describe_table(self, abs_table_path: str) -> ydb.TableDescription:
return self._retry_operation_in_pool(self._describe_table, abs_table_path)
settings = self._get_request_settings()
return self._retry_operation_in_pool(self._describe_table, abs_table_path, settings)

def check_exists(self, abs_table_path: str) -> bool:
settings = self._get_request_settings()
try:
self._retry_operation_in_pool(self._describe_path, abs_table_path)
self._retry_operation_in_pool(self._describe_path, abs_table_path, settings)
return True
except ydb.SchemeError:
return False

@_handle_ydb_errors
def get_table_names(self, abs_dir_path: str) -> List[str]:
directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path)
settings = self._get_request_settings()
directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path, settings)
result = []
for child in directory.children:
child_abs_path = posixpath.join(abs_dir_path, child.name)
Expand Down Expand Up @@ -180,79 +177,95 @@ def _make_data_query(
def _execute_scan_query(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
) -> Generator[ydb.convert.ResultSet, None, None]:
settings = self._get_request_settings()
prepared_query = query
if isinstance(query, str) and parameters:
prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query)
prepared_query: ydb.DataQuery = self._retry_operation_in_pool(self._prepare, query, settings)

if isinstance(query, str):
scan_query = ydb.ScanQuery(query, None)
else:
scan_query = ydb.ScanQuery(prepared_query.yql_text, prepared_query.parameters_types)

return self._execute_scan_query_in_driver(scan_query, parameters)
return self._execute_scan_query_in_driver(scan_query, parameters, settings)

@_handle_ydb_errors
def _execute_dml(
self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None
) -> ydb.convert.ResultSets:
settings = self._get_request_settings()
prepared_query = query
if isinstance(query, str) and parameters:
if self.tx_context:
prepared_query = self._run_operation_in_session(self._prepare, query)
prepared_query = self._run_operation_in_session(self._prepare, query, settings)
else:
prepared_query = self._retry_operation_in_pool(self._prepare, query)
prepared_query = self._retry_operation_in_pool(self._prepare, query, settings)

if self.tx_context:
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters)
return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters, settings)

return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters)
return self._retry_operation_in_pool(
self._execute_in_session, self.tx_mode, prepared_query, parameters, settings
)

@_handle_ydb_errors
def _execute_ddl(self, query: str) -> ydb.convert.ResultSets:
return self._retry_operation_in_pool(self._execute_scheme, query)
settings = self._get_request_settings()
return self._retry_operation_in_pool(self._execute_scheme, query, settings)

@staticmethod
def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets:
return session.execute_scheme(query)
def _execute_scheme(
session: ydb.Session,
query: str,
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return session.execute_scheme(query, settings)

@staticmethod
def _describe_table(session: ydb.Session, abs_table_path: str) -> ydb.TableDescription:
return session.describe_table(abs_table_path)
def _describe_table(
session: ydb.Session, abs_table_path: str, settings: ydb.BaseRequestSettings
) -> ydb.TableDescription:
return session.describe_table(abs_table_path, settings)

@staticmethod
def _describe_path(session: ydb.Session, table_path: str) -> ydb.SchemeEntry:
return session._driver.scheme_client.describe_path(table_path)
def _describe_path(session: ydb.Session, table_path: str, settings: ydb.BaseRequestSettings) -> ydb.SchemeEntry:
return session._driver.scheme_client.describe_path(table_path, settings)

@staticmethod
def _list_directory(session: ydb.Session, abs_dir_path: str) -> ydb.Directory:
return session._driver.scheme_client.list_directory(abs_dir_path)
def _list_directory(session: ydb.Session, abs_dir_path: str, settings: ydb.BaseRequestSettings) -> ydb.Directory:
return session._driver.scheme_client.list_directory(abs_dir_path, settings)

@staticmethod
def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery:
return session.prepare(query)
def _prepare(session: ydb.Session, query: str, settings: ydb.BaseRequestSettings) -> ydb.DataQuery:
return session.prepare(query, settings)

@staticmethod
def _execute_in_tx(
tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
tx_context: ydb.TxContext,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return tx_context.execute(prepared_query, parameters, commit_tx=False)
return tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings)

@staticmethod
def _execute_in_session(
session: ydb.Session,
tx_mode: ydb.AbstractTransactionModeBuilder,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings)

def _execute_scan_query_in_driver(
self,
scan_query: ydb.ScanQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> Generator[ydb.convert.ResultSet, None, None]:
chunk: ydb.ScanQueryResult
for chunk in self.driver.table_client.scan_query(scan_query, parameters):
for chunk in self.driver.table_client.scan_query(scan_query, parameters, settings):
yield chunk.result_set

def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs):
Expand Down Expand Up @@ -325,52 +338,85 @@ def close(self):
def rowcount(self):
return len(self._ensure_prefetched())

def _get_request_settings(self) -> ydb.BaseRequestSettings:
settings = self.request_settings.make_copy()

if self.request_settings.trace_id is None:
settings = settings.with_trace_id(maybe_get_current_trace_id())

return settings


class AsyncCursor(Cursor):
_await = staticmethod(util.await_only)

@staticmethod
async def _describe_table(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.TableDescription:
return await session.describe_table(abs_table_path)
async def _describe_table(
session: ydb.aio.table.Session,
abs_table_path: str,
settings: ydb.BaseRequestSettings,
) -> ydb.TableDescription:
return await session.describe_table(abs_table_path, settings)

@staticmethod
async def _describe_path(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.SchemeEntry:
return await session._driver.scheme_client.describe_path(abs_table_path)
async def _describe_path(
session: ydb.aio.table.Session,
abs_table_path: str,
settings: ydb.BaseRequestSettings,
) -> ydb.SchemeEntry:
return await session._driver.scheme_client.describe_path(abs_table_path, settings)

@staticmethod
async def _list_directory(session: ydb.aio.table.Session, abs_dir_path: str) -> ydb.Directory:
return await session._driver.scheme_client.list_directory(abs_dir_path)
async def _list_directory(
session: ydb.aio.table.Session,
abs_dir_path: str,
settings: ydb.BaseRequestSettings,
) -> ydb.Directory:
return await session._driver.scheme_client.list_directory(abs_dir_path, settings)

@staticmethod
async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets:
return await session.execute_scheme(query)
async def _execute_scheme(
session: ydb.aio.table.Session,
query: str,
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return await session.execute_scheme(query, settings)

@staticmethod
async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery:
return await session.prepare(query)
async def _prepare(
session: ydb.aio.table.Session,
query: str,
settings: ydb.BaseRequestSettings,
) -> ydb.DataQuery:
return await session.prepare(query, settings)

@staticmethod
async def _execute_in_tx(
tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]]
tx_context: ydb.aio.table.TxContext,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return await tx_context.execute(prepared_query, parameters, commit_tx=False)
return await tx_context.execute(prepared_query, parameters, commit_tx=False, settings=settings)

@staticmethod
async def _execute_in_session(
session: ydb.aio.table.Session,
tx_mode: ydb.AbstractTransactionModeBuilder,
prepared_query: ydb.DataQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> ydb.convert.ResultSets:
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True)
return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True, settings=settings)

def _execute_scan_query_in_driver(
self,
scan_query: ydb.ScanQuery,
parameters: Optional[Mapping[str, Any]],
settings: ydb.BaseRequestSettings,
) -> Generator[ydb.convert.ResultSet, None, None]:
iterator: AsyncIterator[ydb.ScanQueryResult] = self._await(
self.driver.table_client.scan_query(scan_query, parameters)
self.driver.table_client.scan_query(scan_query, parameters, settings)
)
while True:
try:
Expand Down
16 changes: 16 additions & 0 deletions ydb_sqlalchemy/dbapi/tracing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
import importlib.util
from typing import Optional


def maybe_get_current_trace_id() -> Optional[str]:
# Check if OpenTelemetry is available
if importlib.util.find_spec("opentelemetry"):
from opentelemetry import trace

current_span = trace.get_current_span()

if current_span.get_span_context().is_valid:
return format(current_span.get_span_context().trace_id, "032x")

# Return None if OpenTelemetry is not available or trace ID is invalid
return None
Loading
Loading