Skip to content

Commit

Permalink
Attempt to support sqlalchemy 1.4+
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvoleg committed Nov 13, 2024
1 parent 3c83092 commit da695ec
Show file tree
Hide file tree
Showing 9 changed files with 889 additions and 577 deletions.
32 changes: 16 additions & 16 deletions test/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_sa_crud(self, connection):
(5, "c"),
]

def test_cached_query(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_cached_query(self, connection_no_trans, connection):
table = self.tables.test

with connection_no_trans.begin() as transaction:
Expand Down Expand Up @@ -263,7 +263,7 @@ def test_integer_types(self, connection):
result = connection.execute(stmt).fetchone()
assert result == (b"Uint8", b"Uint16", b"Uint32", b"Uint64", b"Int8", b"Int16", b"Int32", b"Int64")

def test_datetime_types(self, connection: sa.Connection):
def test_datetime_types(self, connection):
stmt = sa.Select(
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_datetime", datetime.datetime.now(), sa.DateTime))),
sa.func.FormatType(sa.func.TypeOf(sa.bindparam("p_DATETIME", datetime.datetime.now(), sa.DATETIME))),
Expand All @@ -273,7 +273,7 @@ def test_datetime_types(self, connection: sa.Connection):
result = connection.execute(stmt).fetchone()
assert result == (b"Timestamp", b"Datetime", b"Timestamp")

def test_datetime_types_timezone(self, connection: sa.Connection):
def test_datetime_types_timezone(self, connection):
table = self.tables.test_datetime_types
tzinfo = datetime.timezone(datetime.timedelta(hours=3, minutes=42))

Expand Down Expand Up @@ -476,7 +476,7 @@ def define_tables(cls, metadata: sa.MetaData):
Column("id", Integer, primary_key=True),
)

def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_rollback(self, connection_no_trans, connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
Expand All @@ -491,7 +491,7 @@ def test_rollback(self, connection_no_trans: sa.Connection, connection: sa.Conne
result = cursor.fetchall()
assert result == []

def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connection):
def test_commit(self, connection_no_trans, connection):
table = self.tables.test

connection_no_trans.execution_options(isolation_level=IsolationLevel.SERIALIZABLE)
Expand All @@ -507,7 +507,7 @@ def test_commit(self, connection_no_trans: sa.Connection, connection: sa.Connect

@pytest.mark.parametrize("isolation_level", (IsolationLevel.SERIALIZABLE, IsolationLevel.SNAPSHOT_READONLY))
def test_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
self, connection_no_trans, connection, isolation_level
):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection
Expand Down Expand Up @@ -536,7 +536,7 @@ def test_interactive_transaction(
),
)
def test_not_interactive_transaction(
self, connection_no_trans: sa.Connection, connection: sa.Connection, isolation_level
self, connection_no_trans, connection, isolation_level
):
table = self.tables.test
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection
Expand Down Expand Up @@ -573,7 +573,7 @@ class IsolationSettings(NamedTuple):
IsolationLevel.SNAPSHOT_READONLY: IsolationSettings(ydb.QuerySnapshotReadOnly().name, True),
}

def test_connection_set(self, connection_no_trans: sa.Connection):
def test_connection_set(self, connection_no_trans):
dbapi_connection: dbapi.Connection = connection_no_trans.connection.dbapi_connection

for sa_isolation_level, ydb_isolation_settings in self.YDB_ISOLATION_SETTINGS_MAP.items():
Expand Down Expand Up @@ -861,7 +861,7 @@ def test_insert_in_name_and_field(self, connection):
class TestSecondaryIndex(TestBase):
__backend__ = True

def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
def test_column_indexes(self, connection, metadata: sa.MetaData):
table = Table(
"test_column_indexes/table",
metadata,
Expand All @@ -884,7 +884,7 @@ def test_column_indexes(self, connection: sa.Connection, metadata: sa.MetaData):
index1 = indexes_map["ix_test_column_indexes_table_index_col2"]
assert index1.index_columns == ["index_col2"]

def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
def test_async_index(self, connection, metadata: sa.MetaData):
table = Table(
"test_async_index/table",
metadata,
Expand All @@ -903,7 +903,7 @@ def test_async_index(self, connection: sa.Connection, metadata: sa.MetaData):
assert set(index.index_columns) == {"index_col1", "index_col2"}
# TODO: Check type after https://github.com/ydb-platform/ydb-python-sdk/issues/351

def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
def test_cover_index(self, connection, metadata: sa.MetaData):
table = Table(
"test_cover_index/table",
metadata,
Expand All @@ -922,7 +922,7 @@ def test_cover_index(self, connection: sa.Connection, metadata: sa.MetaData):
assert set(index.index_columns) == {"index_col1"}
# TODO: Check covered columns after https://github.com/ydb-platform/ydb-python-sdk/issues/409

def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaData):
def test_indexes_reflection(self, connection, metadata: sa.MetaData):
table = Table(
"test_indexes_reflection/table",
metadata,
Expand All @@ -948,7 +948,7 @@ def test_indexes_reflection(self, connection: sa.Connection, metadata: sa.MetaDa
"test_async_cover_index": {"index_col1"},
}

def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_simple_usage(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_simple_usage/persons",
metadata,
Expand Down Expand Up @@ -979,7 +979,7 @@ def test_index_simple_usage(self, connection: sa.Connection, metadata: sa.MetaDa
cursor = connection.execute(select_stmt)
assert cursor.scalar_one() == "Sarah Connor"

def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_with_join_usage(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_with_join_usage/persons",
metadata,
Expand Down Expand Up @@ -1033,7 +1033,7 @@ def test_index_with_join_usage(self, connection: sa.Connection, metadata: sa.Met
cursor = connection.execute(select_stmt)
assert cursor.one() == ("Sarah Connor", "wanted")

def test_index_deletion(self, connection: sa.Connection, metadata: sa.MetaData):
def test_index_deletion(self, connection, metadata: sa.MetaData):
persons = Table(
"test_index_deletion/persons",
metadata,
Expand Down Expand Up @@ -1062,7 +1062,7 @@ def define_tables(cls, metadata: sa.MetaData):
Table("table", metadata, sa.Column("id", sa.Integer, primary_key=True))

@classmethod
def insert_data(cls, connection: sa.Connection):
def insert_data(cls, connection):
table = cls.tables["some_dir/nested_dir/table"]
root_table = cls.tables["table"]

Expand Down
50 changes: 25 additions & 25 deletions test/test_suite.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
from sqlalchemy.testing.suite.test_types import DateTimeTest as _DateTimeTest
from sqlalchemy.testing.suite.test_types import IntegerTest as _IntegerTest
from sqlalchemy.testing.suite.test_types import JSONTest as _JSONTest
from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest
# from sqlalchemy.testing.suite.test_types import NativeUUIDTest as _NativeUUIDTest
from sqlalchemy.testing.suite.test_types import NumericTest as _NumericTest
from sqlalchemy.testing.suite.test_types import StringTest as _StringTest
from sqlalchemy.testing.suite.test_types import (
Expand All @@ -78,7 +78,7 @@
TimestampMicrosecondsTest as _TimestampMicrosecondsTest,
)
from sqlalchemy.testing.suite.test_types import TimeTest as _TimeTest
from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest
# from sqlalchemy.testing.suite.test_types import TrueDivTest as _TrueDivTest

from ydb_sqlalchemy.sqlalchemy import types as ydb_sa_types

Expand Down Expand Up @@ -275,30 +275,30 @@ class BinaryTest(_BinaryTest):
pass


class TrueDivTest(_TrueDivTest):
@pytest.mark.skip("Unsupported builtin: FLOOR")
def test_floordiv_numeric(self, connection, left, right, expected):
pass
# class TrueDivTest(_TrueDivTest):
# @pytest.mark.skip("Unsupported builtin: FLOOR")
# def test_floordiv_numeric(self, connection, left, right, expected):
# pass

@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer(self, connection, left, right, expected):
pass
# @pytest.mark.skip("Truediv unsupported for int")
# def test_truediv_integer(self, connection, left, right, expected):
# pass

@pytest.mark.skip("Truediv unsupported for int")
def test_truediv_integer_bound(self, connection):
pass
# @pytest.mark.skip("Truediv unsupported for int")
# def test_truediv_integer_bound(self, connection):
# pass

@pytest.mark.skip("Numeric is not Decimal")
def test_truediv_numeric(self):
# SqlAlchemy maybe eat Decimal and throw Double
pass
# @pytest.mark.skip("Numeric is not Decimal")
# def test_truediv_numeric(self):
# # SqlAlchemy maybe eat Decimal and throw Double
# pass

@testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
def test_truediv_float(self, connection, left, right, expected):
eq_(
connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))),
expected,
)
# @testing.combinations(("6.25", "2.5", 2.5), argnames="left, right, expected")
# def test_truediv_float(self, connection, left, right, expected):
# eq_(
# connection.scalar(select(literal_column(left, type_=sa.Float()) / literal_column(right, type_=sa.Float()))),
# expected,
# )


class ExistsTest(_ExistsTest):
Expand Down Expand Up @@ -539,9 +539,9 @@ def test_from_as_table(self, connection):
eq_(connection.execute(sa.select(table)).fetchall(), [(1,), (2,), (3,)])


@pytest.mark.skip("uuid unsupported for columns")
class NativeUUIDTest(_NativeUUIDTest):
pass
# @pytest.mark.skip("uuid unsupported for columns")
# class NativeUUIDTest(_NativeUUIDTest):
# pass


@pytest.mark.skip("unsupported Time data type")
Expand Down
Loading

0 comments on commit da695ec

Please sign in to comment.