Skip to content

Commit

Permalink
compat fixups
Browse files Browse the repository at this point in the history
  • Loading branch information
WillAyd committed Oct 3, 2023
1 parent 7223e63 commit c2cd90a
Showing 1 changed file with 42 additions and 10 deletions.
52 changes: 42 additions & 10 deletions pandas/tests/io/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ def create_and_load_iris_postgresql(conn, iris_file: Path):
def create_and_load_iris(conn, iris_file: Path, dialect: str):
from sqlalchemy import insert

import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
iris = iris_table_metadata(dialect)

with iris_file.open(newline=None, encoding="utf-8") as csvfile:
Expand All @@ -223,7 +222,12 @@ def create_and_load_iris_view(conn):
cur = conn.cursor()
cur.execute(stmt)
else:
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency(
"adbc_driver_manager.dbapi", errors="ignore"
)
if adbc and isinstance(conn, adbc.Connection):
with conn.cursor() as cur:
cur.execute(stmt)
Expand Down Expand Up @@ -301,8 +305,12 @@ def create_and_load_types_sqlite3(conn, types_data: list[dict]):

def create_and_load_types_postgresql(conn, types_data: list[dict]):
# Boolean support not added until 0.8.0
adbc = import_optional_dependency("adbc_driver_manager")
if Version(adbc.__version__) < Version("0.8.0"):
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency("adbc_driver_manager", errors="ignore")

if adbc and Version(adbc.__version__) < Version("0.8.0"):
bool_type = "INTEGER"
else:
bool_type = "BOOLEAN"
Expand Down Expand Up @@ -363,7 +371,10 @@ def check_iris_frame(frame: DataFrame):

def count_rows(conn, table_name: str):
stmt = f"SELECT count(*) AS count_1 FROM {table_name}"
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if isinstance(conn, sqlite3.Connection):
cur = conn.cursor()
return cur.execute(stmt).fetchone()[0]
Expand Down Expand Up @@ -495,7 +506,12 @@ def get_all_views(conn):
c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'")
return [view[0] for view in c.fetchall()]
else:
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency(
"adbc_driver_manager.dbapi", errors="ignore"
)
if adbc and isinstance(conn, adbc.Connection):
results = []
info = conn.adbc_get_objects().read_all().to_pylist()
Expand All @@ -520,7 +536,13 @@ def get_all_tables(conn):
c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'")
return [table[0] for table in c.fetchall()]
else:
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency(
"adbc_driver_manager.dbapi", errors="ignore"
)

if adbc and isinstance(conn, adbc.Connection):
results = []
info = conn.adbc_get_objects().read_all().to_pylist()
Expand All @@ -547,7 +569,12 @@ def drop_table(
conn.commit()

else:
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency(
"adbc_driver_manager.dbapi", errors="ignore"
)
if adbc and isinstance(conn, adbc.Connection):
with conn.cursor() as cur:
cur.execute(f'DROP TABLE IF EXISTS "{table_name}"')
Expand All @@ -564,7 +591,12 @@ def drop_view(
conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}")
conn.commit()
else:
adbc = import_optional_dependency("adbc_driver_manager.dbapi", errors="ignore")
if pa_version_under8p0:
adbc = None
else:
adbc = import_optional_dependency(
"adbc_driver_manager.dbapi", errors="ignore"
)
if adbc and isinstance(conn, adbc.Connection):
with conn.cursor() as cur:
cur.execute(f'DROP VIEW IF EXISTS "{view_name}"')
Expand Down Expand Up @@ -1796,7 +1828,7 @@ def test_api_timedelta(conn, request):
)

if "adbc" in conn_name:
exp_warning = None
exp_warning = FutureWarning # pyarrow warns is_sparse is deprecated
else:
exp_warning = UserWarning

Expand Down

0 comments on commit c2cd90a

Please sign in to comment.