From c1af4f5ede3f6dfb1d1652c62afbcdc88b5eef91 Mon Sep 17 00:00:00 2001 From: eduardo naufel schettino Date: Sat, 29 Dec 2018 05:03:12 +0800 Subject: [PATCH] =?UTF-8?q?ENH:=20to=5Fsql()=20add=20parameter=20"method"?= =?UTF-8?q?=20to=20control=20insertions=20method=20(#8=E2=80=A6=20(#21401)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * ENH: to_sql() add parameter "method" to control insertions method (#8953) * ENH: to_sql() add parameter "method". Fix docstrings (#8953) * ENH: to_sql() add parameter "method". Improve docs based on reviews (#8953) * ENH: to_sql() add parameter "method". Fix unit-test (#8953) * doc clean-up * additional doc clean-up * use dict(zip()) directly * clean up merge * default --> None * Remove stray default * Remove method kwarg * change default to None * test copy insert snippit * print debug * index=False * Add reference to documentation --- doc/source/io.rst | 48 ++++++++++++++++++ doc/source/whatsnew/v0.24.0.rst | 1 + pandas/core/generic.py | 15 +++++- pandas/io/sql.py | 88 ++++++++++++++++++++++++++++----- pandas/tests/io/test_sql.py | 65 ++++++++++++++++++++++-- 5 files changed, 199 insertions(+), 18 deletions(-) diff --git a/doc/source/io.rst b/doc/source/io.rst index b22f52e448c0d..7230ff55f9a6c 100644 --- a/doc/source/io.rst +++ b/doc/source/io.rst @@ -4989,6 +4989,54 @@ with respect to the timezone. timezone aware or naive. When reading ``TIMESTAMP WITH TIME ZONE`` types, pandas will convert the data to UTC. +.. _io.sql.method: + +Insertion Method +++++++++++++++++ + +.. versionadded:: 0.24.0 + +The parameter ``method`` controls the SQL insertion clause used. +Possible values are: + +- ``None``: Uses standard SQL ``INSERT`` clause (one per row). +- ``'multi'``: Pass multiple values in a single ``INSERT`` clause. + It uses a *special* SQL syntax not supported by all backends. 
+ This usually provides better performance for analytic databases + like *Presto* and *Redshift*, but has worse performance for + traditional SQL backend if the table contains many columns. + For more information check the SQLAlchemy `documentation + <http://docs.sqlalchemy.org/en/latest/core/dml.html#sqlalchemy.sql.expression.Insert.values.params.*args>`__. +- callable with signature ``(pd_table, conn, keys, data_iter)``: + This can be used to implement a more performant insertion method based on + specific backend dialect features. + +Example of a callable using PostgreSQL `COPY clause +<https://www.postgresql.org/docs/current/static/sql-copy.html>`__:: + + # Alternative to_sql() *method* for DBs that support COPY FROM + import csv + from io import StringIO + + def psql_insert_copy(table, conn, keys, data_iter): + # gets a DBAPI connection that can provide a cursor + dbapi_conn = conn.connection + with dbapi_conn.cursor() as cur: + s_buf = StringIO() + writer = csv.writer(s_buf) + writer.writerows(data_iter) + s_buf.seek(0) + + columns = ', '.join('"{}"'.format(k) for k in keys) + if table.schema: + table_name = '{}.{}'.format(table.schema, table.name) + else: + table_name = table.name + + sql = 'COPY {} ({}) FROM STDIN WITH CSV'.format( + table_name, columns) + cur.copy_expert(sql=sql, file=s_buf) + Reading Tables '''''''''''''' diff --git a/doc/source/whatsnew/v0.24.0.rst b/doc/source/whatsnew/v0.24.0.rst index b2e52bbb3cb8f..0b0ba7aab49aa 100644 --- a/doc/source/whatsnew/v0.24.0.rst +++ b/doc/source/whatsnew/v0.24.0.rst @@ -377,6 +377,7 @@ Other Enhancements - :meth:`DataFrame.between_time` and :meth:`DataFrame.at_time` have gained the ``axis`` parameter (:issue:`8839`) - The ``scatter_matrix``, ``andrews_curves``, ``parallel_coordinates``, ``lag_plot``, ``autocorrelation_plot``, ``bootstrap_plot``, and ``radviz`` plots from the ``pandas.plotting`` module are now accessible from calling :meth:`DataFrame.plot` (:issue:`11978`) - :class:`IntervalIndex` has gained the :attr:`~IntervalIndex.is_overlapping` attribute to indicate if the ``IntervalIndex`` contains any overlapping intervals (:issue:`23309`) +- 
:func:`pandas.DataFrame.to_sql` has gained the ``method`` argument to control the SQL insertion clause. See the :ref:`insertion method <io.sql.method>` section in the documentation. (:issue:`8953`) .. _whatsnew_0240.api_breaking: diff --git a/pandas/core/generic.py b/pandas/core/generic.py index efb3f20202c42..a8d5e4aa772cc 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -2386,7 +2386,7 @@ def to_msgpack(self, path_or_buf=None, encoding='utf-8', **kwargs): **kwargs) def to_sql(self, name, con, schema=None, if_exists='fail', index=True, - index_label=None, chunksize=None, dtype=None): + index_label=None, chunksize=None, dtype=None, method=None): """ Write records stored in a DataFrame to a SQL database. @@ -2424,6 +2424,17 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True, Specifying the datatype for columns. The keys should be the column names and the values should be the SQLAlchemy types or strings for the sqlite3 legacy mode. + method : {None, 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method <io.sql.method>`. + + .. 
versionadded:: 0.24.0 Raises ------ @@ -2505,7 +2516,7 @@ def to_sql(self, name, con, schema=None, if_exists='fail', index=True, from pandas.io import sql sql.to_sql(self, name, con, schema=schema, if_exists=if_exists, index=index, index_label=index_label, chunksize=chunksize, - dtype=dtype) + dtype=dtype, method=method) def to_pickle(self, path, compression='infer', protocol=pkl.HIGHEST_PROTOCOL): diff --git a/pandas/io/sql.py b/pandas/io/sql.py index e54d29148c6d0..6093c6c3fd0fc 100644 --- a/pandas/io/sql.py +++ b/pandas/io/sql.py @@ -8,6 +8,7 @@ from contextlib import contextmanager from datetime import date, datetime, time +from functools import partial import re import warnings @@ -395,7 +396,7 @@ def read_sql(sql, con, index_col=None, coerce_float=True, params=None, def to_sql(frame, name, con, schema=None, if_exists='fail', index=True, - index_label=None, chunksize=None, dtype=None): + index_label=None, chunksize=None, dtype=None, method=None): """ Write records stored in a DataFrame to a SQL database. @@ -429,6 +430,17 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True, Optional specifying the datatype for columns. The SQL type should be a SQLAlchemy type, or a string for sqlite3 fallback connection. If all columns are of the same type, one single value can be used. + method : {None, 'multi', callable}, default None + Controls the SQL insertion clause used: + + - None : Uses standard SQL ``INSERT`` clause (one per row). + - 'multi': Pass multiple values in a single ``INSERT`` clause. + - callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method <io.sql.method>`. + + .. 
versionadded:: 0.24.0 """ if if_exists not in ('fail', 'replace', 'append'): raise ValueError("'{0}' is not valid for if_exists".format(if_exists)) @@ -443,7 +455,7 @@ def to_sql(frame, name, con, schema=None, if_exists='fail', index=True, pandas_sql.to_sql(frame, name, if_exists=if_exists, index=index, index_label=index_label, schema=schema, - chunksize=chunksize, dtype=dtype) + chunksize=chunksize, dtype=dtype, method=method) def has_table(table_name, con, schema=None): @@ -568,8 +580,29 @@ def create(self): else: self._execute_create() - def insert_statement(self): - return self.table.insert() + def _execute_insert(self, conn, keys, data_iter): + """Execute SQL statement inserting data + + Parameters + ---------- + conn : sqlalchemy.engine.Engine or sqlalchemy.engine.Connection + keys : list of str + Column names + data_iter : generator of list + Each item contains a list of values to be inserted + """ + data = [dict(zip(keys, row)) for row in data_iter] + conn.execute(self.table.insert(), data) + + def _execute_insert_multi(self, conn, keys, data_iter): + """Alternative to _execute_insert for DBs support multivalue INSERT. + + Note: multi-value insert is usually faster for analytics DBs + and tables containing a few columns + but performance degrades quickly with increase of columns. 
+ """ + data = [dict(zip(keys, row)) for row in data_iter] + conn.execute(self.table.insert(data)) def insert_data(self): if self.index is not None: @@ -612,11 +645,18 @@ def insert_data(self): return column_names, data_list - def _execute_insert(self, conn, keys, data_iter): - data = [dict(zip(keys, row)) for row in data_iter] - conn.execute(self.insert_statement(), data) + def insert(self, chunksize=None, method=None): + + # set insert method + if method is None: + exec_insert = self._execute_insert + elif method == 'multi': + exec_insert = self._execute_insert_multi + elif callable(method): + exec_insert = partial(method, self) + else: + raise ValueError('Invalid parameter `method`: {}'.format(method)) - def insert(self, chunksize=None): keys, data_list = self.insert_data() nrows = len(self.frame) @@ -639,7 +679,7 @@ def insert(self, chunksize=None): break chunk_iter = zip(*[arr[start_i:end_i] for arr in data_list]) - self._execute_insert(conn, keys, chunk_iter) + exec_insert(conn, keys, chunk_iter) def _query_iterator(self, result, chunksize, columns, coerce_float=True, parse_dates=None): @@ -1085,7 +1125,8 @@ def read_query(self, sql, index_col=None, coerce_float=True, read_sql = read_query def to_sql(self, frame, name, if_exists='fail', index=True, - index_label=None, schema=None, chunksize=None, dtype=None): + index_label=None, schema=None, chunksize=None, dtype=None, + method=None): """ Write records stored in a DataFrame to a SQL database. @@ -1115,7 +1156,17 @@ def to_sql(self, frame, name, if_exists='fail', index=True, Optional specifying the datatype for columns. The SQL type should be a SQLAlchemy type. If all columns are of the same type, one single value can be used. + method : {None, 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. 
+ * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method <io.sql.method>`. + .. versionadded:: 0.24.0 """ if dtype and not is_dict_like(dtype): dtype = {col_name: dtype for col_name in frame} @@ -1131,7 +1182,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, if_exists=if_exists, index_label=index_label, schema=schema, dtype=dtype) table.create() - table.insert(chunksize) + table.insert(chunksize, method=method) if (not name.isdigit() and not name.islower()): # check for potentially case sensitivity issues (GH7815) # Only check when name is not a number and name is not lower case @@ -1442,7 +1493,8 @@ def _fetchall_as_list(self, cur): return result def to_sql(self, frame, name, if_exists='fail', index=True, - index_label=None, schema=None, chunksize=None, dtype=None): + index_label=None, schema=None, chunksize=None, dtype=None, + method=None): """ Write records stored in a DataFrame to a SQL database. @@ -1471,7 +1523,17 @@ def to_sql(self, frame, name, if_exists='fail', index=True, Optional specifying the datatype for columns. The SQL type should be a string. If all columns are of the same type, one single value can be used. + method : {None, 'multi', callable}, default None + Controls the SQL insertion clause used: + + * None : Uses standard SQL ``INSERT`` clause (one per row). + * 'multi': Pass multiple values in a single ``INSERT`` clause. + * callable with signature ``(pd_table, conn, keys, data_iter)``. + + Details and a sample callable implementation can be found in the + section :ref:`insert method <io.sql.method>`. + .. 
versionadded:: 0.24.0 """ if dtype and not is_dict_like(dtype): dtype = {col_name: dtype for col_name in frame} @@ -1486,7 +1548,7 @@ def to_sql(self, frame, name, if_exists='fail', index=True, if_exists=if_exists, index_label=index_label, dtype=dtype) table.create() - table.insert(chunksize) + table.insert(chunksize, method) def has_table(self, name, schema=None): # TODO(wesm): unused? diff --git a/pandas/tests/io/test_sql.py b/pandas/tests/io/test_sql.py index eeeb55cb8e70c..c346103a70c98 100644 --- a/pandas/tests/io/test_sql.py +++ b/pandas/tests/io/test_sql.py @@ -375,12 +375,16 @@ def _read_sql_iris_named_parameter(self): iris_frame = self.pandasSQL.read_query(query, params=params) self._check_iris_loaded_frame(iris_frame) - def _to_sql(self): + def _to_sql(self, method=None): self.drop_table('test_frame1') - self.pandasSQL.to_sql(self.test_frame1, 'test_frame1') + self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=method) assert self.pandasSQL.has_table('test_frame1') + num_entries = len(self.test_frame1) + num_rows = self._count_rows('test_frame1') + assert num_rows == num_entries + # Nuke table self.drop_table('test_frame1') @@ -434,6 +438,25 @@ def _to_sql_append(self): assert num_rows == num_entries self.drop_table('test_frame1') + def _to_sql_method_callable(self): + check = [] # used to double check function below is really being used + + def sample(pd_table, conn, keys, data_iter): + check.append(1) + data = [dict(zip(keys, row)) for row in data_iter] + conn.execute(pd_table.table.insert(), data) + self.drop_table('test_frame1') + + self.pandasSQL.to_sql(self.test_frame1, 'test_frame1', method=sample) + assert self.pandasSQL.has_table('test_frame1') + + assert check == [1] + num_entries = len(self.test_frame1) + num_rows = self._count_rows('test_frame1') + assert num_rows == num_entries + # Nuke table + self.drop_table('test_frame1') + def _roundtrip(self): self.drop_table('test_frame_roundtrip') self.pandasSQL.to_sql(self.test_frame1, 
'test_frame_roundtrip') @@ -1193,7 +1216,7 @@ def setup_connect(self): pytest.skip( "Can't connect to {0} server".format(self.flavor)) - def test_aread_sql(self): + def test_read_sql(self): self._read_sql_iris() def test_read_sql_parameter(self): @@ -1217,6 +1240,12 @@ def test_to_sql_replace(self): def test_to_sql_append(self): self._to_sql_append() + def test_to_sql_method_multi(self): + self._to_sql(method='multi') + + def test_to_sql_method_callable(self): + self._to_sql_method_callable() + def test_create_table(self): temp_conn = self.connect() temp_frame = DataFrame( @@ -1930,6 +1959,36 @@ def test_schema_support(self): res2 = pdsql.read_table('test_schema_other2') tm.assert_frame_equal(res1, res2) + def test_copy_from_callable_insertion_method(self): + # GH 8953 + # Example in io.rst found under _io.sql.method + # not available in sqlite, mysql + def psql_insert_copy(table, conn, keys, data_iter): + # gets a DBAPI connection that can provide a cursor + dbapi_conn = conn.connection + with dbapi_conn.cursor() as cur: + s_buf = compat.StringIO() + writer = csv.writer(s_buf) + writer.writerows(data_iter) + s_buf.seek(0) + + columns = ', '.join('"{}"'.format(k) for k in keys) + if table.schema: + table_name = '{}.{}'.format(table.schema, table.name) + else: + table_name = table.name + + sql_query = 'COPY {} ({}) FROM STDIN WITH CSV'.format( + table_name, columns) + cur.copy_expert(sql=sql_query, file=s_buf) + + expected = DataFrame({'col1': [1, 2], 'col2': [0.1, 0.2], + 'col3': ['a', 'n']}) + expected.to_sql('test_copy_insert', self.conn, index=False, + method=psql_insert_copy) + result = sql.read_sql_table('test_copy_insert', self.conn) + tm.assert_frame_equal(result, expected) + @pytest.mark.single class TestMySQLAlchemy(_TestMySQLAlchemy, _TestSQLAlchemy):