From c73dc7f8be7d54f82fa1ff515335f167105ca022 Mon Sep 17 00:00:00 2001 From: Chuck Cadman <51368516+cdcadman@users.noreply.github.com> Date: Thu, 9 Feb 2023 09:35:08 -0800 Subject: [PATCH] Make pandas/io/sql.py work with sqlalchemy 2.0 (#48576) --- ci/deps/actions-310.yaml | 2 +- ci/deps/actions-311.yaml | 2 +- ci/deps/actions-38-downstream_compat.yaml | 2 +- ci/deps/actions-38.yaml | 2 +- ci/deps/actions-39.yaml | 2 +- ci/deps/circle-38-arm64.yaml | 2 +- doc/source/user_guide/io.rst | 4 +- doc/source/whatsnew/v2.0.0.rst | 1 + environment.yml | 2 +- pandas/core/generic.py | 17 ++- pandas/io/sql.py | 84 ++++++------ pandas/tests/io/test_sql.py | 159 ++++++++++++++++------ requirements-dev.txt | 2 +- 13 files changed, 188 insertions(+), 93 deletions(-) diff --git a/ci/deps/actions-310.yaml b/ci/deps/actions-310.yaml index 25032ed1c76b0..24676856f9fad 100644 --- a/ci/deps/actions-310.yaml +++ b/ci/deps/actions-310.yaml @@ -48,7 +48,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - tzdata>=2022a - xarray diff --git a/ci/deps/actions-311.yaml b/ci/deps/actions-311.yaml index aef97c232e940..0e90f2f87198f 100644 --- a/ci/deps/actions-311.yaml +++ b/ci/deps/actions-311.yaml @@ -48,7 +48,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - tzdata>=2022a - xarray diff --git a/ci/deps/actions-38-downstream_compat.yaml b/ci/deps/actions-38-downstream_compat.yaml index 1de392a9cc277..8f6fe60403b18 100644 --- a/ci/deps/actions-38-downstream_compat.yaml +++ b/ci/deps/actions-38-downstream_compat.yaml @@ -48,7 +48,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - xarray - xlrd diff --git a/ci/deps/actions-38.yaml b/ci/deps/actions-38.yaml index 803b0bdbff793..ea9e3fea365a0 100644 --- a/ci/deps/actions-38.yaml +++ b/ci/deps/actions-38.yaml @@ -48,7 +48,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - xarray - xlrd diff --git a/ci/deps/actions-39.yaml b/ci/deps/actions-39.yaml index 5ce5681aa9e21..80cf1c1e539b7 100644 --- a/ci/deps/actions-39.yaml +++ b/ci/deps/actions-39.yaml @@ -48,7 +48,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - tzdata>=2022a - xarray diff --git a/ci/deps/circle-38-arm64.yaml b/ci/deps/circle-38-arm64.yaml index 7dcb84dc8874c..e8fc1a1459943 100644 --- a/ci/deps/circle-38-arm64.yaml +++ b/ci/deps/circle-38-arm64.yaml @@ -49,7 +49,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - xarray - xlrd diff --git a/doc/source/user_guide/io.rst b/doc/source/user_guide/io.rst index 50aabad2d0bd3..1c3cdd9f4cffd 100644 --- a/doc/source/user_guide/io.rst +++ b/doc/source/user_guide/io.rst @@ -5868,7 +5868,7 @@ If you have an SQLAlchemy description of your database you can express where con sa.Column("Col_3", sa.Boolean), ) - pd.read_sql(sa.select([data_table]).where(data_table.c.Col_3 is True), engine) + pd.read_sql(sa.select(data_table).where(data_table.c.Col_3 is True), engine) You can combine SQLAlchemy expressions with parameters passed to :func:`read_sql` using :func:`sqlalchemy.bindparam` @@ -5876,7 +5876,7 @@ You can combine SQLAlchemy expressions with parameters passed to :func:`read_sql import datetime as dt - expr = sa.select([data_table]).where(data_table.c.Date > sa.bindparam("date")) + expr = sa.select(data_table).where(data_table.c.Date > sa.bindparam("date")) pd.read_sql(expr, engine, params={"date": dt.datetime(2010, 10, 18)}) diff --git a/doc/source/whatsnew/v2.0.0.rst b/doc/source/whatsnew/v2.0.0.rst index 3c9c249e0f6ea..d067c4df0ada6 100644 --- a/doc/source/whatsnew/v2.0.0.rst +++ b/doc/source/whatsnew/v2.0.0.rst @@ -294,6 +294,7 @@ Other enhancements - Added :meth:`DatetimeIndex.as_unit` and :meth:`TimedeltaIndex.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`50616`) - Added :meth:`Series.dt.unit` and :meth:`Series.dt.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`51223`) - Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`) +- Added support for SQLAlchemy 2.0 (:issue:`40686`) - .. --------------------------------------------------------------------------- diff --git a/environment.yml b/environment.yml index 05251001d8e86..9169cbf08b45d 100644 --- a/environment.yml +++ b/environment.yml @@ -51,7 +51,7 @@ dependencies: - pyxlsb - s3fs>=2021.08.0 - scipy - - sqlalchemy<1.4.46 + - sqlalchemy - tabulate - tzdata>=2022a - xarray diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 272e97abf35a9..2bf28b275d1df 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2711,7 +2711,7 @@ def to_sql( library. Legacy support is provided for sqlite3.Connection objects. The user is responsible for engine disposal and connection closure for the SQLAlchemy connectable. See `here \ - `_. + `_. If passing a sqlalchemy.engine.Connection which is already in a transaction, the transaction will not be committed. If passing a sqlite3.Connection, it will not be possible to roll back the record insertion. @@ -2761,7 +2761,7 @@ def to_sql( attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not reflect the exact number of written rows as stipulated in the `sqlite3 `__ or - `SQLAlchemy `__. + `SQLAlchemy `__. .. versionadded:: 1.4.0 @@ -2805,7 +2805,9 @@ def to_sql( >>> df.to_sql('users', con=engine) 3 - >>> engine.execute("SELECT * FROM users").fetchall() + >>> from sqlalchemy import text + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() [(0, 'User 1'), (1, 'User 2'), (2, 'User 3')] An `sqlalchemy.engine.Connection` can also be passed to `con`: @@ -2821,7 +2823,8 @@ def to_sql( >>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']}) >>> df2.to_sql('users', con=engine, if_exists='append') 2 - >>> engine.execute("SELECT * FROM users").fetchall() + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() [(0, 'User 1'), (1, 'User 2'), (2, 'User 3'), (0, 'User 4'), (1, 'User 5'), (0, 'User 6'), (1, 'User 7')] @@ -2831,7 +2834,8 @@ def to_sql( >>> df2.to_sql('users', con=engine, if_exists='replace', ... index_label='id') 2 - >>> engine.execute("SELECT * FROM users").fetchall() + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM users")).fetchall() [(0, 'User 6'), (1, 'User 7')] Specify the dtype (especially useful for integers with missing values). @@ -2851,7 +2855,8 @@ def to_sql( ... dtype={"A": Integer()}) 3 - >>> engine.execute("SELECT * FROM integers").fetchall() + >>> with engine.connect() as conn: + ... conn.execute(text("SELECT * FROM integers")).fetchall() [(1,), (None,), (2,)] """ # noqa:E501 from pandas.io import sql diff --git a/pandas/io/sql.py b/pandas/io/sql.py index d88decc8601f0..6764c0578bf7a 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -69,23 +69,16 @@ if TYPE_CHECKING: from sqlalchemy import Table + from sqlalchemy.sql.expression import ( + Select, + TextClause, + ) # ----------------------------------------------------------------------------- # -- Helper functions -def _convert_params(sql, params): - """Convert SQL and params args to DBAPI2.0 compliant format.""" - args = [sql] - if params is not None: - if hasattr(params, "keys"): # test if params is a mapping - args += [params] - else: - args += [list(params)] - return args - - def _process_parse_dates_argument(parse_dates): """Process parse_dates argument for read_sql functions""" # handle non-list entries for parse_dates gracefully @@ -224,8 +217,7 @@ def execute(sql, con, params=None): if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)): raise TypeError("pandas.io.sql.execute requires a connection") # GH50185 with pandasSQL_builder(con, need_transaction=True) as pandas_sql: - args = _convert_params(sql, params) - return pandas_sql.execute(*args) + return pandas_sql.execute(sql, params) # ----------------------------------------------------------------------------- @@ -348,7 +340,7 @@ def read_sql_table( else using_nullable_dtypes() ) - with pandasSQL_builder(con, schema=schema) as pandas_sql: + with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql: if not pandas_sql.has_table(table_name): raise ValueError(f"Table {table_name} not found") @@ -951,7 +943,8 @@ def sql_schema(self) -> str: def _execute_create(self) -> None: # Inserting table into database, add to MetaData object self.table = self.table.to_metadata(self.pd_sql.meta) - self.table.create(bind=self.pd_sql.con) + with self.pd_sql.run_transaction(): + self.table.create(bind=self.pd_sql.con) def create(self) -> None: if self.exists(): @@ -1221,7 +1214,7 @@ def _create_table_setup(self): column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type) - columns = [ + columns: list[Any] = [ Column(name, typ, index=is_index) for name, typ, is_index in column_names_and_types ] @@ -1451,7 +1444,7 @@ def to_sql( pass @abstractmethod - def execute(self, *args, **kwargs): + def execute(self, sql: str | Select | TextClause, params=None): pass @abstractmethod @@ -1511,7 +1504,7 @@ def insert_records( try: return table.insert(chunksize=chunksize, method=method) - except exc.SQLAlchemyError as err: + except exc.StatementError as err: # GH34431 # https://stackoverflow.com/a/67358288/6067848 msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?# @@ -1579,13 +1572,18 @@ def __init__( from sqlalchemy.engine import Engine from sqlalchemy.schema import MetaData + # self.exit_stack cleans up the Engine and Connection and commits the + # transaction if any of those objects was created below. + # Cleanup happens either in self.__exit__ or at the end of the iterator + # returned by read_sql when chunksize is not None. self.exit_stack = ExitStack() if isinstance(con, str): con = create_engine(con) + self.exit_stack.callback(con.dispose) if isinstance(con, Engine): con = self.exit_stack.enter_context(con.connect()) - if need_transaction: - self.exit_stack.enter_context(con.begin()) + if need_transaction and not con.in_transaction(): + self.exit_stack.enter_context(con.begin()) self.con = con self.meta = MetaData(schema=schema) self.returns_generator = False @@ -1596,11 +1594,18 @@ def __exit__(self, *args) -> None: @contextmanager def run_transaction(self): - yield self.con + if not self.con.in_transaction(): + with self.con.begin(): + yield self.con + else: + yield self.con - def execute(self, *args, **kwargs): + def execute(self, sql: str | Select | TextClause, params=None): """Simple passthrough to SQLAlchemy connectable""" - return self.con.execute(*args, **kwargs) + args = [] if params is None else [params] + if isinstance(sql, str): + return self.con.exec_driver_sql(sql, *args) + return self.con.execute(sql, *args) def read_table( self, @@ -1780,9 +1785,7 @@ def read_query( read_sql """ - args = _convert_params(sql, params) - - result = self.execute(*args) + result = self.execute(sql, params) columns = result.keys() if chunksize is not None: @@ -1838,13 +1841,14 @@ def prep_table( else: dtype = cast(dict, dtype) - from sqlalchemy.types import ( - TypeEngine, - to_instance, - ) + from sqlalchemy.types import TypeEngine for col, my_type in dtype.items(): - if not isinstance(to_instance(my_type), TypeEngine): + if isinstance(my_type, type) and issubclass(my_type, TypeEngine): + pass + elif isinstance(my_type, TypeEngine): + pass + else: raise ValueError(f"The type of {col} is not a SQLAlchemy type") table = SQLTable( @@ -2005,7 +2009,8 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None: schema = schema or self.meta.schema if self.has_table(table_name, schema): self.meta.reflect(bind=self.con, only=[table_name], schema=schema) - self.get_table(table_name, schema).drop(bind=self.con) + with self.run_transaction(): + self.get_table(table_name, schema).drop(bind=self.con) self.meta.clear() def _create_sql_schema( @@ -2238,21 +2243,24 @@ def run_transaction(self): finally: cur.close() - def execute(self, *args, **kwargs): + def execute(self, sql: str | Select | TextClause, params=None): + if not isinstance(sql, str): + raise TypeError("Query must be a string unless using sqlalchemy.") + args = [] if params is None else [params] cur = self.con.cursor() try: - cur.execute(*args, **kwargs) + cur.execute(sql, *args) return cur except Exception as exc: try: self.con.rollback() except Exception as inner_exc: # pragma: no cover ex = DatabaseError( - f"Execution failed on sql: {args[0]}\n{exc}\nunable to rollback" + f"Execution failed on sql: {sql}\n{exc}\nunable to rollback" ) raise ex from inner_exc - ex = DatabaseError(f"Execution failed on sql '{args[0]}': {exc}") + ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}") raise ex from exc @staticmethod @@ -2305,9 +2313,7 @@ def read_query( dtype: DtypeArg | None = None, use_nullable_dtypes: bool = False, ) -> DataFrame | Iterator[DataFrame]: - - args = _convert_params(sql, params) - cursor = self.execute(*args) + cursor = self.execute(sql, params) columns = [col_desc[0] for col_desc in cursor.description] if chunksize is not None: diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index 3ccc3bdd94f7e..7207ff2356a03 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -149,8 +149,6 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str): from sqlalchemy.engine import Engine iris = iris_table_metadata(dialect) - iris.drop(conn, checkfirst=True) - iris.create(bind=conn) with iris_file.open(newline=None) as csvfile: reader = csv.reader(csvfile) @@ -160,9 +158,14 @@ def create_and_load_iris(conn, iris_file: Path, dialect: str): if isinstance(conn, Engine): with conn.connect() as conn: with conn.begin(): + iris.drop(conn, checkfirst=True) + iris.create(bind=conn) conn.execute(stmt) else: - conn.execute(stmt) + with conn.begin(): + iris.drop(conn, checkfirst=True) + iris.create(bind=conn) + conn.execute(stmt) def create_and_load_iris_view(conn): @@ -180,7 +183,8 @@ def create_and_load_iris_view(conn): with conn.begin(): conn.execute(stmt) else: - conn.execute(stmt) + with conn.begin(): + conn.execute(stmt) def types_table_metadata(dialect: str): @@ -243,16 +247,19 @@ def create_and_load_types(conn, types_data: list[dict], dialect: str): from sqlalchemy.engine import Engine types = types_table_metadata(dialect) - types.drop(conn, checkfirst=True) - types.create(bind=conn) stmt = insert(types).values(types_data) if isinstance(conn, Engine): with conn.connect() as conn: with conn.begin(): + types.drop(conn, checkfirst=True) + types.create(bind=conn) conn.execute(stmt) else: - conn.execute(stmt) + with conn.begin(): + types.drop(conn, checkfirst=True) + types.create(bind=conn) + conn.execute(stmt) def check_iris_frame(frame: DataFrame): @@ -269,25 +276,21 @@ def count_rows(conn, table_name: str): cur = conn.cursor() return cur.execute(stmt).fetchone()[0] else: - from sqlalchemy import ( - create_engine, - text, - ) + from sqlalchemy import create_engine from sqlalchemy.engine import Engine - stmt = text(stmt) if isinstance(conn, str): try: engine = create_engine(conn) with engine.connect() as conn: - return conn.execute(stmt).scalar_one() + return conn.exec_driver_sql(stmt).scalar_one() finally: engine.dispose() elif isinstance(conn, Engine): with conn.connect() as conn: - return conn.execute(stmt).scalar_one() + return conn.exec_driver_sql(stmt).scalar_one() else: - return conn.execute(stmt).scalar_one() + return conn.exec_driver_sql(stmt).scalar_one() @pytest.fixture @@ -417,7 +420,8 @@ def mysql_pymysql_engine(iris_path, types_data): @pytest.fixture def mysql_pymysql_conn(mysql_pymysql_engine): - yield mysql_pymysql_engine.connect() + with mysql_pymysql_engine.connect() as conn: + yield conn @pytest.fixture @@ -443,7 +447,8 @@ def postgresql_psycopg2_engine(iris_path, types_data): @pytest.fixture def postgresql_psycopg2_conn(postgresql_psycopg2_engine): - yield postgresql_psycopg2_engine.connect() + with postgresql_psycopg2_engine.connect() as conn: + yield conn @pytest.fixture @@ -463,7 +468,8 @@ def sqlite_engine(sqlite_str): @pytest.fixture def sqlite_conn(sqlite_engine): - yield sqlite_engine.connect() + with sqlite_engine.connect() as conn: + yield conn @pytest.fixture @@ -483,7 +489,8 @@ def sqlite_iris_engine(sqlite_engine, iris_path): @pytest.fixture def sqlite_iris_conn(sqlite_iris_engine): - yield sqlite_iris_engine.connect() + with sqlite_iris_engine.connect() as conn: + yield conn @pytest.fixture @@ -533,12 +540,20 @@ def sqlite_buildin_iris(sqlite_buildin, iris_path): all_connectable_iris = sqlalchemy_connectable_iris + ["sqlite_buildin_iris"] +@pytest.mark.db +@pytest.mark.parametrize("conn", all_connectable) +def test_dataframe_to_sql(conn, test_frame1, request): + # GH 51086 if conn is sqlite_engine + conn = request.getfixturevalue(conn) + test_frame1.to_sql("test", conn, if_exists="append", index=False) + + @pytest.mark.db @pytest.mark.parametrize("conn", all_connectable) @pytest.mark.parametrize("method", [None, "multi"]) def test_to_sql(conn, method, test_frame1, request): conn = request.getfixturevalue(conn) - with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: pandasSQL.to_sql(test_frame1, "test_frame", method=method) assert pandasSQL.has_table("test_frame") assert count_rows(conn, "test_frame") == len(test_frame1) @@ -549,7 +564,7 @@ def test_to_sql(conn, method, test_frame1, request): @pytest.mark.parametrize("mode, num_row_coef", [("replace", 1), ("append", 2)]) def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request): conn = request.getfixturevalue(conn) - with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") pandasSQL.to_sql(test_frame1, "test_frame", if_exists=mode) assert pandasSQL.has_table("test_frame") @@ -560,7 +575,7 @@ def test_to_sql_exist(conn, mode, num_row_coef, test_frame1, request): @pytest.mark.parametrize("conn", all_connectable) def test_to_sql_exist_fail(conn, test_frame1, request): conn = request.getfixturevalue(conn) - with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: pandasSQL.to_sql(test_frame1, "test_frame", if_exists="fail") assert pandasSQL.has_table("test_frame") @@ -595,9 +610,45 @@ def test_read_iris_query_chunksize(conn, request): assert "SepalWidth" in iris_frame.columns +@pytest.mark.db +@pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) +def test_read_iris_query_expression_with_parameter(conn, request): + conn = request.getfixturevalue(conn) + from sqlalchemy import ( + MetaData, + Table, + create_engine, + select, + ) + + metadata = MetaData() + autoload_con = create_engine(conn) if isinstance(conn, str) else conn + iris = Table("iris", metadata, autoload_with=autoload_con) + iris_frame = read_sql_query( + select(iris), conn, params={"name": "Iris-setosa", "length": 5.1} + ) + check_iris_frame(iris_frame) + if isinstance(conn, str): + autoload_con.dispose() + + +@pytest.mark.db +@pytest.mark.parametrize("conn", all_connectable_iris) +def test_read_iris_query_string_with_parameter(conn, request): + for db, query in SQL_STRINGS["read_parameters"].items(): + if db in conn: + break + else: + raise KeyError(f"No part of {conn} found in SQL_STRINGS['read_parameters']") + conn = request.getfixturevalue(conn) + iris_frame = read_sql_query(query, conn, params=("Iris-setosa", 5.1)) + check_iris_frame(iris_frame) + + @pytest.mark.db @pytest.mark.parametrize("conn", sqlalchemy_connectable_iris) def test_read_iris_table(conn, request): + # GH 51015 if conn = sqlite_iris_str conn = request.getfixturevalue(conn) iris_frame = read_sql_table("iris", conn) check_iris_frame(iris_frame) @@ -627,7 +678,7 @@ def sample(pd_table, conn, keys, data_iter): data = [dict(zip(keys, row)) for row in data_iter] conn.execute(pd_table.table.insert(), data) - with pandasSQL_builder(conn) as pandasSQL: + with pandasSQL_builder(conn, need_transaction=True) as pandasSQL: pandasSQL.to_sql(test_frame1, "test_frame", method=sample) assert pandasSQL.has_table("test_frame") assert check == [1] @@ -680,7 +731,8 @@ def test_read_procedure(conn, request): with engine_conn.begin(): engine_conn.execute(proc) else: - conn.execute(proc) + with conn.begin(): + conn.execute(proc) res1 = sql.read_sql_query("CALL get_testdb();", conn) tm.assert_frame_equal(df, res1) @@ -762,6 +814,8 @@ def teardown_method(self): pass else: with conn: + for view in self._get_all_views(conn): + self.drop_view(view, conn) for tbl in self._get_all_tables(conn): self.drop_table(tbl, conn) @@ -778,6 +832,14 @@ def _get_all_tables(self, conn): c = conn.execute("SELECT name FROM sqlite_master WHERE type='table'") return [table[0] for table in c.fetchall()] + def drop_view(self, view_name, conn): + conn.execute(f"DROP VIEW IF EXISTS {sql._get_valid_sqlite_name(view_name)}") + conn.commit() + + def _get_all_views(self, conn): + c = conn.execute("SELECT name FROM sqlite_master WHERE type='view'") + return [view[0] for view in c.fetchall()] + class SQLAlchemyMixIn(MixInBase): @classmethod @@ -788,6 +850,8 @@ def connect(self): return self.engine.connect() def drop_table(self, table_name, conn): + if conn.in_transaction(): + conn.get_transaction().rollback() with conn.begin(): sql.SQLDatabase(conn).drop_table(table_name) @@ -796,6 +860,20 @@ def _get_all_tables(self, conn): return inspect(conn).get_table_names() + def drop_view(self, view_name, conn): + quoted_view = conn.engine.dialect.identifier_preparer.quote_identifier( + view_name + ) + if conn.in_transaction(): + conn.get_transaction().rollback() + with conn.begin(): + conn.exec_driver_sql(f"DROP VIEW IF EXISTS {quoted_view}") + + def _get_all_views(self, conn): + from sqlalchemy import inspect + + return inspect(conn).get_view_names() + class PandasSQLTest: """ @@ -822,7 +900,7 @@ def load_types_data(self, types_data): def _read_sql_iris_parameter(self): query = SQL_STRINGS["read_parameters"][self.flavor] - params = ["Iris-setosa", 5.1] + params = ("Iris-setosa", 5.1) iris_frame = self.pandasSQL.read_query(query, params=params) check_iris_frame(iris_frame) @@ -951,8 +1029,6 @@ class _TestSQLApi(PandasSQLTest): @pytest.fixture(autouse=True) def setup_method(self, iris_path, types_data): self.conn = self.connect() - if not isinstance(self.conn, sqlite3.Connection): - self.conn.begin() self.load_iris_data(iris_path) self.load_types_data(types_data) self.load_test_data_and_sql() @@ -1448,7 +1524,8 @@ def test_not_reflect_all_tables(self): with conn.begin(): conn.execute(query) else: - self.conn.execute(query) + with self.conn.begin(): + self.conn.execute(query) with tm.assert_produces_warning(None): sql.read_sql_table("other_table", self.conn) @@ -1698,7 +1775,6 @@ def setup_class(cls): def setup_method(self, iris_path, types_data): try: self.conn = self.engine.connect() - self.conn.begin() self.pandasSQL = sql.SQLDatabase(self.conn) except sqlalchemy.exc.OperationalError: pytest.skip(f"Can't connect to {self.flavor} server") @@ -1729,8 +1805,8 @@ def test_create_table(self): temp_frame = DataFrame( {"one": [1.0, 2.0, 3.0, 4.0], "two": [4.0, 3.0, 2.0, 1.0]} ) - pandasSQL = sql.SQLDatabase(temp_conn) - assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 + with sql.SQLDatabase(temp_conn, need_transaction=True) as pandasSQL: + assert pandasSQL.to_sql(temp_frame, "temp_frame") == 4 insp = inspect(temp_conn) assert insp.has_table("temp_frame") @@ -1749,6 +1825,10 @@ def test_drop_table(self): assert insp.has_table("temp_frame") pandasSQL.drop_table("temp_frame") + try: + insp.clear_cache() # needed with SQLAlchemy 2.0, unavailable prior + except AttributeError: + pass assert not insp.has_table("temp_frame") def test_roundtrip(self, test_frame1): @@ -2098,7 +2178,6 @@ def _get_index_columns(self, tbl_name): def test_to_sql_save_index(self): self._to_sql_save_index() - @pytest.mark.xfail(reason="Nested transactions rollbacks don't work with Pandas") def test_transactions(self): self._transaction_test() @@ -2120,7 +2199,8 @@ def test_get_schema_create_table(self, test_frame3): with conn.begin(): conn.execute(create_sql) else: - self.conn.execute(create_sql) + with self.conn.begin(): + self.conn.execute(create_sql) returned_df = sql.read_sql_table(tbl, self.conn) tm.assert_frame_equal(returned_df, blank_test_df, check_index_type=False) self.drop_table(tbl, self.conn) @@ -2586,7 +2666,8 @@ class Test(BaseModel): id = Column(Integer, primary_key=True) string_column = Column(String(50)) - BaseModel.metadata.create_all(self.conn) + with self.conn.begin(): + BaseModel.metadata.create_all(self.conn) Session = sessionmaker(bind=self.conn) with Session() as session: df = DataFrame({"id": [0, 1], "string_column": ["hello", "world"]}) @@ -2680,8 +2761,9 @@ def test_schema_support(self): df = DataFrame({"col1": [1, 2], "col2": [0.1, 0.2], "col3": ["a", "n"]}) # create a schema - self.conn.execute("DROP SCHEMA IF EXISTS other CASCADE;") - self.conn.execute("CREATE SCHEMA other;") + with self.conn.begin(): + self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;") + self.conn.exec_driver_sql("CREATE SCHEMA other;") # write dataframe to different schema's assert df.to_sql("test_schema_public", self.conn, index=False) == 2 @@ -2713,8 +2795,9 @@ def test_schema_support(self): # different if_exists options # create a schema - self.conn.execute("DROP SCHEMA IF EXISTS other CASCADE;") - self.conn.execute("CREATE SCHEMA other;") + with self.conn.begin(): + self.conn.exec_driver_sql("DROP SCHEMA IF EXISTS other CASCADE;") + self.conn.exec_driver_sql("CREATE SCHEMA other;") # write dataframe with different if_exists options assert ( diff --git a/requirements-dev.txt b/requirements-dev.txt index 3783c7c2aeb5f..b6992a7266600 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -40,7 +40,7 @@ python-snappy pyxlsb s3fs>=2021.08.0 scipy -sqlalchemy<1.4.46 +sqlalchemy tabulate tzdata>=2022.1 xarray