From 575196557e46a21d28e3f8b63193fac2423cf981 Mon Sep 17 00:00:00 2001 From: Keiron Pizzey Date: Sat, 8 Jul 2017 13:30:16 +0100 Subject: [PATCH 1/2] ENH - Modify Dataframe.select_dtypes to accept scalar values This commit related to GH16855. It allows the Dataframe.select_dtypes function to accept scalar values as well as list-like values. As such, it should maintain backwards compatibility. --- doc/source/basics.rst | 4 ---- doc/source/style.ipynb | 2 +- doc/source/whatsnew/v0.21.0.txt | 1 + pandas/core/frame.py | 21 ++++++++++++--------- pandas/tests/frame/test_dtypes.py | 26 ++++++++++++++------------ 5 files changed, 28 insertions(+), 26 deletions(-) diff --git a/doc/source/basics.rst b/doc/source/basics.rst index 134cc5106015b..d8b1602fb104d 100644 --- a/doc/source/basics.rst +++ b/doc/source/basics.rst @@ -2229,7 +2229,3 @@ All numpy dtypes are subclasses of ``numpy.generic``: Pandas also defines the types ``category``, and ``datetime64[ns, tz]``, which are not integrated into the normal numpy hierarchy and wont show up with the above function. - -.. note:: - - The ``include`` and ``exclude`` parameters must be non-string sequences. diff --git a/doc/source/style.ipynb b/doc/source/style.ipynb index 4eeda491426b1..c250787785e14 100644 --- a/doc/source/style.ipynb +++ b/doc/source/style.ipynb @@ -935,7 +935,7 @@ "\n", "*Experimental: This is a new feature and still under development. We'll be adding features and possibly making breaking changes in future releases. We'd love to hear your feedback.*\n", "\n", - "Some support is available for exporting styled `DataFrames` to Excel worksheets using the `OpenPyXL` engine. CSS2.2 properties handled include:\n", + "Some support is available for exporting styled `DataFrames` to Excel worksheets using the `OpenPyXL` engine. CSS2.2 properties handled include:\n", "\n", "- `background-color`\n", "- `border-style`, `border-width`, `border-color` and their {`top`, `right`, `bottom`, `left` variants}\n", diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index d5cc3d6ddca8e..6968bbebc836c 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -39,6 +39,7 @@ Other Enhancements - :func:`read_feather` has gained the ``nthreads`` parameter for multi-threaded operations (:issue:`16359`) - :func:`DataFrame.clip()` and :func:`Series.clip()` have gained an ``inplace`` argument. (:issue:`15388`) - :func:`crosstab` has gained a ``margins_name`` parameter to define the name of the row / column that will contain the totals when ``margins=True``. (:issue:`15972`) +- :func:`Dataframe.select_dtypes` now accepts scalar values for include/exclude as well as list-like. (:issue:`16855`) .. _whatsnew_0210.api_breaking: diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 80cdebc24c39d..9e7a1ec805ab7 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -2285,9 +2285,9 @@ def select_dtypes(self, include=None, exclude=None): Parameters ---------- - include, exclude : list-like - A list of dtypes or strings to be included/excluded. You must pass - in a non-empty sequence for at least one of these. + include, exclude : scalar or list-like + A selection of dtypes or strings to be included/excluded. At least + one of these parameters must be supplied. Raises ------ @@ -2295,8 +2295,6 @@ def select_dtypes(self, include=None, exclude=None): * If both of ``include`` and ``exclude`` are empty * If ``include`` and ``exclude`` have overlapping elements * If any kind of string dtype is passed in. - TypeError - * If either of ``include`` or ``exclude`` is not a sequence Returns ------- @@ -2348,10 +2346,15 @@ def select_dtypes(self, include=None, exclude=None): 4 True 5 False """ - include, exclude = include or (), exclude or () - if not (is_list_like(include) and is_list_like(exclude)): - raise TypeError('include and exclude must both be non-string' - ' sequences') + + # GH16855 - If either include or exclude is a non-None scalar then + # convert to a tuple of length 1 and continue. + # This allows, for example, df.select_dtypes(include='object'). + if not is_list_like(include): + include = (include,) if include is not None else () + if not is_list_like(exclude): + exclude = (exclude,) if exclude is not None else () + selection = tuple(map(frozenset, (include, exclude))) if not any(selection): diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index 335b76ff2aade..1ef0025ad43fc 100644 --- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -149,6 +149,10 @@ def test_select_dtypes_include(self): ei = df[['k']] assert_frame_equal(ri, ei) + ri = df.select_dtypes(include='category') + ei = df[['f']] + assert_frame_equal(ri, ei) + pytest.raises(NotImplementedError, lambda: df.select_dtypes(include=['period'])) @@ -162,6 +166,10 @@ def test_select_dtypes_exclude(self): ee = df[['a', 'e']] assert_frame_equal(re, ee) + re = df.select_dtypes(exclude=np.number) + ee = df[['a', 'e']] + assert_frame_equal(re, ee) + def test_select_dtypes_exclude_include(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), @@ -181,6 +189,12 @@ def test_select_dtypes_exclude_include(self): e = df[['b', 'e']] assert_frame_equal(r, e) + exclude = 'datetime' + include = 'bool' + r = df.select_dtypes(include=include, exclude=exclude) + e = df[['e']] + assert_frame_equal(r, e) + def test_select_dtypes_not_an_attr_but_still_valid_dtype(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), @@ -205,18 +219,6 @@ def test_select_dtypes_empty(self): 'must be nonempty'): df.select_dtypes() - def test_select_dtypes_raises_on_string(self): - df = DataFrame({'a': list('abc'), 'b': list(range(1, 4))}) - with tm.assert_raises_regex(TypeError, 'include and exclude ' - '.+ non-'): - df.select_dtypes(include='object') - with tm.assert_raises_regex(TypeError, 'include and exclude ' - '.+ non-'): - df.select_dtypes(exclude='object') - with tm.assert_raises_regex(TypeError, 'include and exclude ' - '.+ non-'): - df.select_dtypes(include=int, exclude='object') - def test_select_dtypes_bad_datetime64(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), From 2f9ae6e550d6461d89654180961d28d02fe5a7f5 Mon Sep 17 00:00:00 2001 From: Keiron Pizzey Date: Sat, 8 Jul 2017 16:54:12 +0100 Subject: [PATCH 2/2] Changes according to comments Includes: - Adding another example to the docstring - Separated out the tests into explicit functions to cover the different scenario. Doesn't implement the change in logic suggested as it doesn't work as suggested. Happy to discuss other logic changes as necessary. --- pandas/core/frame.py | 11 ++- pandas/tests/frame/test_dtypes.py | 130 +++++++++++++++++++++++++----- 2 files changed, 118 insertions(+), 23 deletions(-) diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 9e7a1ec805ab7..6559fc4c24ce2 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -2329,6 +2329,14 @@ def select_dtypes(self, include=None, exclude=None): 3 0.0764 False 2 4 -0.9703 True 1 5 -1.2094 False 2 + >>> df.select_dtypes(include='bool') + c + 0 True + 1 False + 2 True + 3 False + 4 True + 5 False >>> df.select_dtypes(include=['float64']) c 0 1 @@ -2347,9 +2355,6 @@ def select_dtypes(self, include=None, exclude=None): 5 False """ - # GH16855 - If either include or exclude is a non-None scalar then - # convert to a tuple of length 1 and continue. - # This allows, for example, df.select_dtypes(include='object'). if not is_list_like(include): include = (include,) if include is not None else () if not is_list_like(exclude): diff --git a/pandas/tests/frame/test_dtypes.py b/pandas/tests/frame/test_dtypes.py index 1ef0025ad43fc..065580d56a683 100644 --- a/pandas/tests/frame/test_dtypes.py +++ b/pandas/tests/frame/test_dtypes.py @@ -104,7 +104,7 @@ def test_dtypes_are_correct_after_column_slice(self): ('b', np.float_), ('c', np.float_)]))) - def test_select_dtypes_include(self): + def test_select_dtypes_include_using_list_like(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), 'c': np.arange(3, 6).astype('u1'), @@ -145,18 +145,10 @@ def test_select_dtypes_include(self): ei = df[['h', 'i']] assert_frame_equal(ri, ei) - ri = df.select_dtypes(include=['timedelta']) - ei = df[['k']] - assert_frame_equal(ri, ei) - - ri = df.select_dtypes(include='category') - ei = df[['f']] - assert_frame_equal(ri, ei) - pytest.raises(NotImplementedError, lambda: df.select_dtypes(include=['period'])) - def test_select_dtypes_exclude(self): + def test_select_dtypes_exclude_using_list_like(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), 'c': np.arange(3, 6).astype('u1'), @@ -166,11 +158,7 @@ def test_select_dtypes_exclude(self): ee = df[['a', 'e']] assert_frame_equal(re, ee) - re = df.select_dtypes(exclude=np.number) - ee = df[['a', 'e']] - assert_frame_equal(re, ee) - - def test_select_dtypes_exclude_include(self): + def test_select_dtypes_exclude_include_using_list_like(self): df = DataFrame({'a': list('abc'), 'b': list(range(1, 4)), 'c': np.arange(3, 6).astype('u1'), @@ -189,11 +177,113 @@ def test_select_dtypes_exclude_include(self): e = df[['b', 'e']] assert_frame_equal(r, e) - exclude = 'datetime' - include = 'bool' - r = df.select_dtypes(include=include, exclude=exclude) - e = df[['e']] - assert_frame_equal(r, e) + def test_select_dtypes_include_using_scalars(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(include=np.number) + ei = df[['b', 'c', 'd', 'k']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include='datetime') + ei = df[['g']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include='datetime64') + ei = df[['g']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include='category') + ei = df[['f']] + assert_frame_equal(ri, ei) + + pytest.raises(NotImplementedError, + lambda: df.select_dtypes(include='period')) + + def test_select_dtypes_exclude_using_scalars(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(exclude=np.number) + ei = df[['a', 'e', 'f', 'g', 'h', 'i', 'j']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(exclude='category') + ei = df[['a', 'b', 'c', 'd', 'e', 'g', 'h', 'i', 'j', 'k']] + assert_frame_equal(ri, ei) + + pytest.raises(NotImplementedError, + lambda: df.select_dtypes(exclude='period')) + + def test_select_dtypes_include_exclude_using_scalars(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(include=np.number, exclude='floating') + ei = df[['b', 'c', 'k']] + assert_frame_equal(ri, ei) + + def test_select_dtypes_include_exclude_mixed_scalars_lists(self): + df = DataFrame({'a': list('abc'), + 'b': list(range(1, 4)), + 'c': np.arange(3, 6).astype('u1'), + 'd': np.arange(4.0, 7.0, dtype='float64'), + 'e': [True, False, True], + 'f': pd.Categorical(list('abc')), + 'g': pd.date_range('20130101', periods=3), + 'h': pd.date_range('20130101', periods=3, + tz='US/Eastern'), + 'i': pd.date_range('20130101', periods=3, + tz='CET'), + 'j': pd.period_range('2013-01', periods=3, + freq='M'), + 'k': pd.timedelta_range('1 day', periods=3)}) + + ri = df.select_dtypes(include=np.number, + exclude=['floating', 'timedelta']) + ei = df[['b', 'c']] + assert_frame_equal(ri, ei) + + ri = df.select_dtypes(include=[np.number, 'category'], + exclude='floating') + ei = df[['b', 'c', 'f', 'k']] + assert_frame_equal(ri, ei) def test_select_dtypes_not_an_attr_but_still_valid_dtype(self): df = DataFrame({'a': list('abc'),