Skip to content

Commit

Permalink
Make pandas/io/sql.py work with sqlalchemy 2.0 (#48576)
Browse files: browse the repository at this point in the history
  • Loading branch information
cdcadman authored Feb 9, 2023
1 parent 7ffc0ad commit c73dc7f
Show file tree
Hide file tree
Showing 13 changed files with 188 additions and 93 deletions.
2 changes: 1 addition & 1 deletion ci/deps/actions-310.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-311.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-38-downstream_compat.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- xarray
- xlrd
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-38.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- xarray
- xlrd
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/actions-39.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
Expand Down
2 changes: 1 addition & 1 deletion ci/deps/circle-38-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- xarray
- xlrd
Expand Down
4 changes: 2 additions & 2 deletions doc/source/user_guide/io.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5868,15 +5868,15 @@ If you have an SQLAlchemy description of your database you can express where con
sa.Column("Col_3", sa.Boolean),
)
pd.read_sql(sa.select([data_table]).where(data_table.c.Col_3 is True), engine)
pd.read_sql(sa.select(data_table).where(data_table.c.Col_3 is True), engine)
You can combine SQLAlchemy expressions with parameters passed to :func:`read_sql` using :func:`sqlalchemy.bindparam`

.. ipython:: python
import datetime as dt
expr = sa.select([data_table]).where(data_table.c.Date > sa.bindparam("date"))
expr = sa.select(data_table).where(data_table.c.Date > sa.bindparam("date"))
pd.read_sql(expr, engine, params={"date": dt.datetime(2010, 10, 18)})
Expand Down
1 change: 1 addition & 0 deletions doc/source/whatsnew/v2.0.0.rst
Original file line number Diff line number Diff line change
Expand Up @@ -294,6 +294,7 @@ Other enhancements
- Added :meth:`DatetimeIndex.as_unit` and :meth:`TimedeltaIndex.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`50616`)
- Added :meth:`Series.dt.unit` and :meth:`Series.dt.as_unit` to convert to different resolutions; supported resolutions are "s", "ms", "us", and "ns" (:issue:`51223`)
- Added new argument ``dtype`` to :func:`read_sql` to be consistent with :func:`read_sql_query` (:issue:`50797`)
- Added support for SQLAlchemy 2.0 (:issue:`40686`)
-

.. ---------------------------------------------------------------------------
Expand Down
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ dependencies:
- pyxlsb
- s3fs>=2021.08.0
- scipy
- sqlalchemy<1.4.46
- sqlalchemy
- tabulate
- tzdata>=2022a
- xarray
Expand Down
17 changes: 11 additions & 6 deletions pandas/core/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2711,7 +2711,7 @@ def to_sql(
library. Legacy support is provided for sqlite3.Connection objects. The user
is responsible for engine disposal and connection closure for the SQLAlchemy
connectable. See `here \
<https://docs.sqlalchemy.org/en/14/core/connections.html>`_.
<https://docs.sqlalchemy.org/en/20/core/connections.html>`_.
If passing a sqlalchemy.engine.Connection which is already in a transaction,
the transaction will not be committed. If passing a sqlite3.Connection,
it will not be possible to roll back the record insertion.
Expand Down Expand Up @@ -2761,7 +2761,7 @@ def to_sql(
attribute of ``sqlite3.Cursor`` or SQLAlchemy connectable which may not
reflect the exact number of written rows as stipulated in the
`sqlite3 <https://docs.python.org/3/library/sqlite3.html#sqlite3.Cursor.rowcount>`__ or
`SQLAlchemy <https://docs.sqlalchemy.org/en/14/core/connections.html#sqlalchemy.engine.BaseCursorResult.rowcount>`__.
`SQLAlchemy <https://docs.sqlalchemy.org/en/20/core/connections.html#sqlalchemy.engine.CursorResult.rowcount>`__.
.. versionadded:: 1.4.0
Expand Down Expand Up @@ -2805,7 +2805,9 @@ def to_sql(
>>> df.to_sql('users', con=engine)
3
>>> engine.execute("SELECT * FROM users").fetchall()
>>> from sqlalchemy import text
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3')]
An `sqlalchemy.engine.Connection` can also be passed to `con`:
Expand All @@ -2821,7 +2823,8 @@ def to_sql(
>>> df2 = pd.DataFrame({'name' : ['User 6', 'User 7']})
>>> df2.to_sql('users', con=engine, if_exists='append')
2
>>> engine.execute("SELECT * FROM users").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 1'), (1, 'User 2'), (2, 'User 3'),
(0, 'User 4'), (1, 'User 5'), (0, 'User 6'),
(1, 'User 7')]
Expand All @@ -2831,7 +2834,8 @@ def to_sql(
>>> df2.to_sql('users', con=engine, if_exists='replace',
... index_label='id')
2
>>> engine.execute("SELECT * FROM users").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM users")).fetchall()
[(0, 'User 6'), (1, 'User 7')]
Specify the dtype (especially useful for integers with missing values).
Expand All @@ -2851,7 +2855,8 @@ def to_sql(
... dtype={"A": Integer()})
3
>>> engine.execute("SELECT * FROM integers").fetchall()
>>> with engine.connect() as conn:
... conn.execute(text("SELECT * FROM integers")).fetchall()
[(1,), (None,), (2,)]
""" # noqa:E501
from pandas.io import sql
Expand Down
84 changes: 45 additions & 39 deletions pandas/io/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,23 +69,16 @@

if TYPE_CHECKING:
from sqlalchemy import Table
from sqlalchemy.sql.expression import (
Select,
TextClause,
)


# -----------------------------------------------------------------------------
# -- Helper functions


def _convert_params(sql, params):
"""Convert SQL and params args to DBAPI2.0 compliant format."""
args = [sql]
if params is not None:
if hasattr(params, "keys"): # test if params is a mapping
args += [params]
else:
args += [list(params)]
return args


def _process_parse_dates_argument(parse_dates):
"""Process parse_dates argument for read_sql functions"""
# handle non-list entries for parse_dates gracefully
Expand Down Expand Up @@ -224,8 +217,7 @@ def execute(sql, con, params=None):
if sqlalchemy is not None and isinstance(con, (str, sqlalchemy.engine.Engine)):
raise TypeError("pandas.io.sql.execute requires a connection") # GH50185
with pandasSQL_builder(con, need_transaction=True) as pandas_sql:
args = _convert_params(sql, params)
return pandas_sql.execute(*args)
return pandas_sql.execute(sql, params)


# -----------------------------------------------------------------------------
Expand Down Expand Up @@ -348,7 +340,7 @@ def read_sql_table(
else using_nullable_dtypes()
)

with pandasSQL_builder(con, schema=schema) as pandas_sql:
with pandasSQL_builder(con, schema=schema, need_transaction=True) as pandas_sql:
if not pandas_sql.has_table(table_name):
raise ValueError(f"Table {table_name} not found")

Expand Down Expand Up @@ -951,7 +943,8 @@ def sql_schema(self) -> str:
def _execute_create(self) -> None:
# Inserting table into database, add to MetaData object
self.table = self.table.to_metadata(self.pd_sql.meta)
self.table.create(bind=self.pd_sql.con)
with self.pd_sql.run_transaction():
self.table.create(bind=self.pd_sql.con)

def create(self) -> None:
if self.exists():
Expand Down Expand Up @@ -1221,7 +1214,7 @@ def _create_table_setup(self):

column_names_and_types = self._get_column_names_and_types(self._sqlalchemy_type)

columns = [
columns: list[Any] = [
Column(name, typ, index=is_index)
for name, typ, is_index in column_names_and_types
]
Expand Down Expand Up @@ -1451,7 +1444,7 @@ def to_sql(
pass

@abstractmethod
def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
pass

@abstractmethod
Expand Down Expand Up @@ -1511,7 +1504,7 @@ def insert_records(

try:
return table.insert(chunksize=chunksize, method=method)
except exc.SQLAlchemyError as err:
except exc.StatementError as err:
# GH34431
# https://stackoverflow.com/a/67358288/6067848
msg = r"""(\(1054, "Unknown column 'inf(e0)?' in 'field list'"\))(?#
Expand Down Expand Up @@ -1579,13 +1572,18 @@ def __init__(
from sqlalchemy.engine import Engine
from sqlalchemy.schema import MetaData

# self.exit_stack cleans up the Engine and Connection and commits the
# transaction if any of those objects was created below.
# Cleanup happens either in self.__exit__ or at the end of the iterator
# returned by read_sql when chunksize is not None.
self.exit_stack = ExitStack()
if isinstance(con, str):
con = create_engine(con)
self.exit_stack.callback(con.dispose)
if isinstance(con, Engine):
con = self.exit_stack.enter_context(con.connect())
if need_transaction:
self.exit_stack.enter_context(con.begin())
if need_transaction and not con.in_transaction():
self.exit_stack.enter_context(con.begin())
self.con = con
self.meta = MetaData(schema=schema)
self.returns_generator = False
Expand All @@ -1596,11 +1594,18 @@ def __exit__(self, *args) -> None:

@contextmanager
def run_transaction(self):
yield self.con
if not self.con.in_transaction():
with self.con.begin():
yield self.con
else:
yield self.con

def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
"""Simple passthrough to SQLAlchemy connectable"""
return self.con.execute(*args, **kwargs)
args = [] if params is None else [params]
if isinstance(sql, str):
return self.con.exec_driver_sql(sql, *args)
return self.con.execute(sql, *args)

def read_table(
self,
Expand Down Expand Up @@ -1780,9 +1785,7 @@ def read_query(
read_sql
"""
args = _convert_params(sql, params)

result = self.execute(*args)
result = self.execute(sql, params)
columns = result.keys()

if chunksize is not None:
Expand Down Expand Up @@ -1838,13 +1841,14 @@ def prep_table(
else:
dtype = cast(dict, dtype)

from sqlalchemy.types import (
TypeEngine,
to_instance,
)
from sqlalchemy.types import TypeEngine

for col, my_type in dtype.items():
if not isinstance(to_instance(my_type), TypeEngine):
if isinstance(my_type, type) and issubclass(my_type, TypeEngine):
pass
elif isinstance(my_type, TypeEngine):
pass
else:
raise ValueError(f"The type of {col} is not a SQLAlchemy type")

table = SQLTable(
Expand Down Expand Up @@ -2005,7 +2009,8 @@ def drop_table(self, table_name: str, schema: str | None = None) -> None:
schema = schema or self.meta.schema
if self.has_table(table_name, schema):
self.meta.reflect(bind=self.con, only=[table_name], schema=schema)
self.get_table(table_name, schema).drop(bind=self.con)
with self.run_transaction():
self.get_table(table_name, schema).drop(bind=self.con)
self.meta.clear()

def _create_sql_schema(
Expand Down Expand Up @@ -2238,21 +2243,24 @@ def run_transaction(self):
finally:
cur.close()

def execute(self, *args, **kwargs):
def execute(self, sql: str | Select | TextClause, params=None):
if not isinstance(sql, str):
raise TypeError("Query must be a string unless using sqlalchemy.")
args = [] if params is None else [params]
cur = self.con.cursor()
try:
cur.execute(*args, **kwargs)
cur.execute(sql, *args)
return cur
except Exception as exc:
try:
self.con.rollback()
except Exception as inner_exc: # pragma: no cover
ex = DatabaseError(
f"Execution failed on sql: {args[0]}\n{exc}\nunable to rollback"
f"Execution failed on sql: {sql}\n{exc}\nunable to rollback"
)
raise ex from inner_exc

ex = DatabaseError(f"Execution failed on sql '{args[0]}': {exc}")
ex = DatabaseError(f"Execution failed on sql '{sql}': {exc}")
raise ex from exc

@staticmethod
Expand Down Expand Up @@ -2305,9 +2313,7 @@ def read_query(
dtype: DtypeArg | None = None,
use_nullable_dtypes: bool = False,
) -> DataFrame | Iterator[DataFrame]:

args = _convert_params(sql, params)
cursor = self.execute(*args)
cursor = self.execute(sql, params)
columns = [col_desc[0] for col_desc in cursor.description]

if chunksize is not None:
Expand Down
Loading

0 comments on commit c73dc7f

Please sign in to comment.