Skip to content

Commit

Permalink
temp
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Nov 1, 2024
1 parent 3f03971 commit c485de0
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 52 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
sqlalchemy >= 2.0.7, < 3.0.0
ydb >= 3.18.6
/Users/ovcharuk/work/ydb-sqlalchemy/ydb_dbapi-0.0.1b1-py3-none-any.whl
ydb >= 3.18.7
ydb-dbapi==0.0.1b2
4 changes: 2 additions & 2 deletions test-requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
pyyaml==5.3.1
greenlet
sqlalchemy==2.0.7
ydb >= 3.18.6
/Users/ovcharuk/work/ydb-sqlalchemy/ydb_dbapi-0.0.1b1-py3-none-any.whl
ydb >= 3.18.7
ydb-dbapi==0.0.1b2
requests<2.29
pytest==7.2.2
docker==6.0.1
Expand Down
92 changes: 46 additions & 46 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,49 +488,49 @@ def insert_data(cls, connection: sa.Connection):
# assert options_after_set["ydb_scan_query"]
# assert "ydb_scan_query" not in options_after_reset

def test_fetchmany(self, connection_no_trans: sa.Connection):
table = self.tables.test
stmt = sa.select(table).where(table.c.id % 2 == 0)
# def test_fetchmany(self, connection_no_trans: sa.Connection):
# table = self.tables.test
# stmt = sa.select(table).where(table.c.id % 2 == 0)

# connection_no_trans.execution_options(ydb_scan_query=True)
cursor = connection_no_trans.execute(stmt)
# # connection_no_trans.execution_options(ydb_scan_query=True)
# cursor = connection_no_trans.execute(stmt)

# assert cursor.cursor.use_scan_query
result = cursor.fetchmany(1000) # fetches only the first 5k rows
assert result == [(i,) for i in range(2000) if i % 2 == 0]
# # assert cursor.cursor.use_scan_query
# result = cursor.fetchmany(1000) # fetches only the first 5k rows
# assert result == [(i,) for i in range(2000) if i % 2 == 0]

def test_fetchall(self, connection_no_trans: sa.Connection):
table = self.tables.test
stmt = sa.select(table).where(table.c.id % 2 == 0)
# def test_fetchall(self, connection_no_trans: sa.Connection):
# table = self.tables.test
# stmt = sa.select(table).where(table.c.id % 2 == 0)

# connection_no_trans.execution_options(ydb_scan_query=True)
cursor = connection_no_trans.execute(stmt)
# # connection_no_trans.execution_options(ydb_scan_query=True)
# cursor = connection_no_trans.execute(stmt)

# assert cursor.cursor.use_scan_query
result = cursor.fetchall()
assert result == [(i,) for i in range(50000) if i % 2 == 0]
# # assert cursor.cursor.use_scan_query
# result = cursor.fetchall()
# assert result == [(i,) for i in range(50000) if i % 2 == 0]

def test_begin_does_nothing(self, connection_no_trans: sa.Connection):
table = self.tables.test
# connection_no_trans.execution_options(ydb_scan_query=True)
# def test_begin_does_nothing(self, connection_no_trans: sa.Connection):
# table = self.tables.test
# # connection_no_trans.execution_options(ydb_scan_query=True)

with connection_no_trans.begin():
cursor = connection_no_trans.execute(sa.select(table))
# with connection_no_trans.begin():
# cursor = connection_no_trans.execute(sa.select(table))

# assert cursor.cursor.use_scan_query
assert cursor.cursor.tx_context is None
# # assert cursor.cursor.use_scan_query
# assert cursor.cursor.tx_context is None

def test_engine_option(self):
table = self.tables.test
engine = self.bind.execution_options()
# def test_engine_option(self):
# table = self.tables.test
# engine = self.bind.execution_options()

with engine.begin() as connection:
cursor = connection.execute(sa.select(table))
# assert cursor.cursor.use_scan_query
# with engine.begin() as connection:
# cursor = connection.execute(sa.select(table))
# # assert cursor.cursor.use_scan_query

with engine.begin() as connection:
cursor = connection.execute(sa.select(table))
# assert cursor.cursor.use_scan_query
# with engine.begin() as connection:
# cursor = connection.execute(sa.select(table))
# # assert cursor.cursor.use_scan_query


class TestTransaction(TablesTest):
Expand Down Expand Up @@ -585,9 +585,9 @@ def test_interactive_transaction(

connection_no_trans.execution_options(isolation_level=isolation_level)
with connection_no_trans.begin():
cursor1 = connection_no_trans.execute(sa.select(table))
tx_id = dbapi_connection._tx_context.tx_id
assert tx_id is not None
cursor1 = connection_no_trans.execute(sa.select(table))
cursor2 = connection_no_trans.execute(sa.select(table))
assert dbapi_connection._tx_context.tx_id == tx_id

Expand Down Expand Up @@ -631,14 +631,14 @@ class IsolationSettings(NamedTuple):
interactive: bool

YDB_ISOLATION_SETTINGS_MAP = {
IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.SerializableReadWrite().name, False),
IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.SerializableReadWrite().name, True),
IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.OnlineReadOnly().name, False),
IsolationLevel.AUTOCOMMIT: IsolationSettings(ydb.QuerySerializableReadWrite().name, False),
IsolationLevel.SERIALIZABLE: IsolationSettings(ydb.QuerySerializableReadWrite().name, True),
IsolationLevel.ONLINE_READONLY: IsolationSettings(ydb.QueryOnlineReadOnly().name, False),
IsolationLevel.ONLINE_READONLY_INCONSISTENT: IsolationSettings(
ydb.OnlineReadOnly().with_allow_inconsistent_reads().name, False
ydb.QueryOnlineReadOnly().with_allow_inconsistent_reads().name, False
),
IsolationLevel.STALE_READONLY: IsolationSettings(ydb.StaleReadOnly().name, False),
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.SnapshotReadOnly().name, True),
IsolationLevel.STALE_READONLY: IsolationSettings(ydb.QueryStaleReadOnly().name, False),
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True),
}

def test_connection_set(self, connection_no_trans: sa.Connection):
Expand Down Expand Up @@ -689,8 +689,8 @@ def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
assert dbapi_conn1.driver is dbapi_conn2.driver
assert dbapi_conn1._session_pool is dbapi_conn2._session_pool
assert dbapi_conn1._driver is dbapi_conn2._driver

engine1.dispose()
engine2.dispose()
Expand All @@ -704,14 +704,13 @@ def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
dbapi_conn2: dbapi.Connection = conn2.connection.dbapi_connection

assert dbapi_conn1.session_pool is dbapi_conn2.session_pool
assert dbapi_conn1.driver is dbapi_conn2.driver
assert dbapi_conn1._session_pool is dbapi_conn2._session_pool
assert dbapi_conn1._driver is dbapi_conn2._driver

engine1.dispose()
engine2.dispose()
assert not ydb_driver._stopped


class TestAsyncEngine(TestEngine):
__only_on__ = "yql+ydb_async"

Expand All @@ -726,14 +725,15 @@ def ydb_driver(self):
finally:
loop.run_until_complete(driver.stop())

@pytest.mark.asyncio
@pytest.fixture(scope="class")
def ydb_pool(self, ydb_driver):
session_pool = ydb.aio.QuerySessionPool(ydb_driver, size=5)
loop = asyncio.get_event_loop()
session_pool = ydb.aio.QuerySessionPool(ydb_driver, size=5, loop=loop)

try:
yield session_pool
finally:
loop = asyncio.get_event_loop()
loop.run_until_complete(session_pool.stop())


Expand Down
5 changes: 4 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ ignore_errors = True
commands =
docker-compose up -d
python {toxinidir}/wait_container_ready.py
pytest -v test/test_core.py --dbdriver ydb --dbdriver ydb_async
pytest -v test/test_core.py::TestText::test_sa_text --dbdriver ydb --dbdriver ydb_async
; pytest -v test_dbapi
pytest -v ydb_sqlalchemy
docker-compose down
Expand Down Expand Up @@ -76,3 +76,6 @@ builtins = _
max-line-length = 120
ignore=E203,W503
exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,docs/*

[pytest]
asyncio_mode = auto
13 changes: 13 additions & 0 deletions ydb_sqlalchemy/driver/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ class AdaptedAsyncConnection(AdaptedConnection):
def __init__(self, connection: AsyncConnection):
self._connection: AsyncConnection = connection

@property
def _driver(self):
return self._connection._driver

@property
def _session_pool(self):
return self._connection._session_pool

@property
def _tx_context(self):
return self._connection._tx_context
Expand All @@ -20,6 +28,11 @@ def _tx_mode(self):
def connection(self):
return self._connection

@property
def interactive_transaction(self):
return self._connection.interactive_transaction


def cursor(self):
return AdaptedAsyncCursor(self._connection.cursor())

Expand Down
18 changes: 17 additions & 1 deletion ydb_sqlalchemy/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,10 @@ def get_bind_types(
for parameter_name, parameter_value in parameters_entry.items():
parameters_values[parameter_name].append(parameter_value)

print(f"Parameters_values: {parameters_values}")
parameter_types = {}
for bind_name in self.bind_names.values():
print(f"bind name {bind_name}")
bind = self.binds[bind_name]

if bind.literal_execute:
Expand All @@ -450,8 +452,12 @@ def get_bind_types(
if not post_compile_bind_values or None in post_compile_bind_values:
is_optional = True

print(f"post compile bind values {post_compile_bind_values}")

bind_type = self._guess_bound_variable_type_by_parameters(bind, post_compile_bind_values)

print(f"bind type {bind_type}")

if bind_type:
for post_compile_bind_name in post_compile_bind_names:
parameter_types[post_compile_bind_name] = YqlTypeCompiler(self.dialect).get_ydb_type(
Expand Down Expand Up @@ -816,6 +822,15 @@ def _add_declare_for_yql_stmt_vars_impl(self, statement, parameters_types):
)
return f"{declarations}\n{statement}"

def __merge_parameters_values_and_types(
self, values: Mapping[str, Any], types: Mapping[str, Any]
) -> Sequence[Mapping[str, ydb.TypedValue]]:
print(values)
print(types)
return {
key: ydb.TypedValue(values[key], types[key]) for key in values.keys()
}

def _make_ydb_operation(
self,
statement: str,
Expand All @@ -827,7 +842,8 @@ def _make_ydb_operation(

if not is_ddl and parameters:
parameters_types = context.compiled.get_bind_types(parameters)
parameters_types = {f"${k}": v for k, v in parameters_types.items()}
parameters = self.__merge_parameters_values_and_types(parameters, parameters_types)
# parameters_types = {f"${k}": v for k, v in parameters_types.items()}
statement, parameters = self._format_variables(statement, parameters, execute_many)
if self._add_declare_for_yql_stmt_vars:
statement = self._add_declare_for_yql_stmt_vars_impl(statement, parameters_types)
Expand Down

0 comments on commit c485de0

Please sign in to comment.