Skip to content

Commit

Permalink
Attempt to support sqlalchemy 1.4+
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Nov 15, 2024
1 parent 3c83092 commit 6835615
Show file tree
Hide file tree
Showing 11 changed files with 767 additions and 586 deletions.
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
sqlalchemy >= 2.0.7, < 3.0.0
sqlalchemy >= 1.4.0, < 3.0.0
ydb >= 3.18.8
ydb-dbapi >= 0.1.1
ydb-dbapi >= 0.1.2
56 changes: 30 additions & 26 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
from ydb_sqlalchemy import sqlalchemy as ydb_sa
from ydb_sqlalchemy.sqlalchemy import types

if sa.__version__ >= "2.":
from sqlalchemy import NullPool
from sqlalchemy import QueuePool
else:
from sqlalchemy.pool import NullPool
from sqlalchemy.pool import QueuePool


def clear_sql(stm):
return stm.replace("\n", " ").replace(" ", " ").strip()
Expand Down Expand Up @@ -94,7 +101,7 @@ def test_sa_crud(self, connection):
(5, "c"),
]

def test_cached_query(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_cached_query(self, connection_no_trans, connection):
table = self.tables.test

with connection_no_trans.begin() as transaction:
Expand Down Expand Up @@ -249,7 +256,7 @@ def test_primitive_types(self, connection):
assert row == (42, "Hello World!", 3.5, True)

def test_integer_types(self, connection):
stmt = sa.Select(
stmt = sa.select(
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint8", 8, types.UInt8))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint16", 16, types.UInt16))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_uint32", 32, types.UInt32))),
Expand All @@ -263,8 +270,8 @@ def test_integer_types(self, connection):
result = connection.execute(stmt).fetchone()
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")

def test_datetime_types(self, connection: sa.Connection):
stmt = sa.Select(
def test_datetime_types(self, connection):
stmt = sa.select(
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_TIMESTAMP", datetime.datetime.now(), sa.TIMESTAMP))),
Expand All @@ -273,7 +280,7 @@ def test_datetime_types(self, connection: sa.Connection):
result = connection.execute(stmt).fetchone()
assert result == (b"Timestamp", b"Datetime", b"Timestamp")

def test_datetime_types_timezone(self, connection: sa.Connection):
def test_datetime_types_timezone(self, connection):
table = self.tables.test_datetime_types
tzinfo = datetime.timezone(datetime.timedelta(hours=3, minutes=42))

Expand Down Expand Up @@ -476,7 +483,8 @@ def define_tables(cls, metadata: sa.MetaData):
Column("id", Integer, primary_key=True),
)

def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection):
@pytest.mark.skipif(sa.__version__ < "2.", reason="Something was different in SA<2, good to fix")
def test_rollback(self, connection_no_trans, connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
Expand All @@ -491,7 +499,7 @@ def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Conne
result = cursor.fetchall()
assert result == []

def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_commit(self, connection_no_trans, connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
Expand All @@ -506,9 +514,7 @@ def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connect
assert set(result) == {(3,), (4,)}

@pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY))
def test_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
def test_interactive_transaction(self, connection_no_trans, connection, isolation_level):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

Expand All @@ -535,9 +541,7 @@ def test_interactive_transaction(
IsolationLevel.AUTOCOMMIT,
),
)
def test_not_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
):
def test_not_interactive_transaction(self, connection_no_trans, connection, isolation_level):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

Expand Down Expand Up @@ -573,7 +577,7 @@ class IsolationSettings(NamedTuple):
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True),
}

def test_connection_set(self, connection_no_trans: sa.Connection):
def test_connection_set(self, connection_no_trans):
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items():
Expand Down Expand Up @@ -614,8 +618,8 @@ def ydb_pool(self, ydb_driver):
session_pool.stop()

def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine1 = sa.create_engine(config.db_url, poolclass=QueuePool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=QueuePool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
Expand All @@ -629,8 +633,8 @@ def test_sa_queue_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
assert not ydb_driver._stopped

def test_sa_null_pool_with_ydb_shared_session_pool(self, ydb_driver, ydb_pool):
engine1 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=sa.NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine1 = sa.create_engine(config.db_url, poolclass=NullPool, connect_args={"ydb_session_pool": ydb_pool})
engine2 = sa.create_engine(config.db_url, poolclass=NullPool, connect_args={"ydb_session_pool": ydb_pool})

with engine1.connect() as conn1, engine2.connect() as conn2:
dbapi_conn1: dbapi.Connection = conn1.connection.dbapi_connection
Expand Down Expand Up @@ -861,7 +865,7 @@ def test_insert_in_name_and_field(self, connection):
class TestSecondaryIndex(TestBase):
__backend__ = True

def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
def test_column_indexes(self, connection, metadata: sa.MetaData):
table = Table(
"test_column_indexes/table",
metadata,
Expand All @@ -884,7 +888,7 @@ def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
index1 = indexes_map["ix_test_column_indexes_table_index_col2"]
assert index1.index_columns == ["index_col2"]

def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
def test_async_index(self, connection, metadata: sa.MetaData):
table = Table(
"test_async_index/table",
metadata,
Expand All @@ -903,7 +907,7 @@ def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
assert set(index.index_columns) == {"index_col1", "index_col2"}
# TODO: Check type after https://github.com/ydb-platform/ydb-python-sdk/issues/351

def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
def test_cover_index(self, connection, metadata: sa.MetaData):
table = Table(
"test_cover_index/table",
metadata,
Expand All @@ -922,7 +926,7 @@ def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
assert set(index.index_columns) == {"index_col1"}
# TODO: Check covered columns after https://github.com/ydb-platform/ydb-python-sdk/issues/409

def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaData):
def test_indexes_reflection(self, connection, metadata: sa.MetaData):
table = Table(
"test_indexes_reflection/table",
metadata,
Expand All @@ -948,7 +952,7 @@ def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaDa
"test_async_cover_index": {"index_col1"},
}

def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_simple_usage(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_simple_usage/persons",
metadata,
Expand Down Expand Up @@ -979,7 +983,7 @@ def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaDa
cursor = connection.execute(select_stmt)
assert cursor.scalar_one() == "Sarah Connor"

def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_with_join_usage(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_with_join_usage/persons",
metadata,
Expand Down Expand Up @@ -1033,7 +1037,7 @@ def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.Met
cursor = connection.execute(select_stmt)
assert cursor.one() == ("Sarah Connor", "wanted")

def test_index_deletion(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_deletion(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_deletion/persons",
metadata,
Expand Down Expand Up @@ -1062,7 +1066,7 @@ def define_tables(cls, metadata: sa.MetaData):
Table("table", metadata, sa.Column("id", sa.Integer, primary_key=True))

@classmethod
def insert_data(cls, connection: sa.Connection):
def insert_data(cls, connection):
table = cls.tables["some_dir/nested_dir/table"]
root_table = cls.tables["table"]

Expand Down
60 changes: 35 additions & 25 deletions test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest
from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest
from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest

from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest
from sqlalchemy.testing.suite.test_types import StringTest as _StringTest
from sqlalchemy.testing.suite.test_types import (
Expand All @@ -78,14 +78,16 @@
TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
)
from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest

from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types

test_types_suite = sqlalchemy.testing.suite.test_types
col_creator = test_types_suite.Column


OLD_SA = sa.__version__ < "2."


def column_getter(*args, **kwargs):
col = col_creator(*args, **kwargs)
if col.name == "x":
Expand Down Expand Up @@ -275,30 +277,35 @@ class BinaryTest(_BinaryTest):
pass


class TrueDivTest(_TrueDivTest):
@pytest.mark.skip("Unsupported builtin: FLOOR")
def test_floordiv_numeric(self, connection, left, right, expected):
pass
if not OLD_SA:
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest

@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer(self, connection, left, right, expected):
pass
class TrueDivTest(_TrueDivTest):
@pytest.mark.skip("Unsupported builtin: FLOOR")
def test_floordiv_numeric(self, connection, left, right, expected):
pass

@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer_bound(self, connection):
pass
@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer(self, connection, left, right, expected):
pass

@pytest.mark.skip("Numeric is not Decimal")
def test_truediv_numeric(self):
# SqlAlchemy maybe eat Decimal and throw Double
pass
@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer_bound(self, connection):
pass

@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
def test_truediv_float(self, connection, left, right, expected):
eq_(
connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))),
expected,
)
@pytest.mark.skip("Numeric is not Decimal")
def test_truediv_numeric(self):
# SqlAlchemy maybe eat Decimal and throw Double
pass

@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
def test_truediv_float(self, connection, left, right, expected):
eq_(
connection.scalar(
select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))
),
expected,
)


class ExistsTest(_ExistsTest):
Expand Down Expand Up @@ -539,9 +546,12 @@ def test_from_as_table(self, connection):
eq_(connection.execute(sa.select(table)).fetchall(), [(1,), (2,), (3,)])


@pytest.mark.skip("uuid unsupported for columns")
class NativeUUIDTest(_NativeUUIDTest):
pass
if not OLD_SA:
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest

@pytest.mark.skip("uuid unsupported for columns")
class NativeUUIDTest(_NativeUUIDTest):
pass


@pytest.mark.skip("unsupported Time data type")
Expand Down
1 change: 1 addition & 0 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,5 @@ max-line-length = 120
ignore=E203,W503
per-file-ignores =
ydb_sqlalchemy/__init__.py: F401
ydb_sqlalchemy/sqlalchemy/compiler/__init__.py: F401
exclude=*_pb2.py,*_grpc.py,.venv,.git,.tox,dist,doc,*egg,docs/*
Loading

0 comments on commit 6835615

Please sign in to comment.