diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index 29b93cfcc91c..d87c2dfdcddb 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -4721,13 +4721,72 @@ def repeat(self, repeats): pd.core.strings.StringMethods, 'get_dummies', reason='non-deferred-columns') - split = frame_base.wont_implement_method( - pd.core.strings.StringMethods, 'split', - reason='non-deferred-columns') + def _split_helper(self, rsplit=False, **kwargs): + expand = kwargs.get('expand', False) - rsplit = frame_base.wont_implement_method( - pd.core.strings.StringMethods, 'rsplit', - reason='non-deferred-columns') + if not expand: + # Not creating separate columns + proxy = self._expr.proxy() + if not rsplit: + func = lambda s: pd.concat([proxy, s.str.split(**kwargs)]) + else: + func = lambda s: pd.concat([proxy, s.str.rsplit(**kwargs)]) + else: + # Creating separate columns, so must be more strict on dtype + dtype = self._expr.proxy().dtype + if not isinstance(dtype, pd.CategoricalDtype): + method_name = 'rsplit' if rsplit else 'split' + raise frame_base.WontImplementError( + f"{method_name}() of non-categorical type is not supported because " + "the type of the output column depends on the data. Please use " + "pd.CategoricalDtype with explicit categories.", + reason="non-deferred-columns") + + # Split the categories + split_cats = dtype.categories.str.split(**kwargs) + + # Count the number of new columns to create for proxy + max_splits = len(max(split_cats, key=len)) + proxy = pd.DataFrame(columns=range(max_splits)) + + def func(s): + if not rsplit: + result = s.str.split(**kwargs) + else: + result = s.str.rsplit(**kwargs) + result[~result.isna()].replace(np.nan, value=None) + return result + + return frame_base.DeferredFrame.wrap( + expressions.ComputedExpression( + 'split', + func, + [self._expr], + proxy=proxy, + requires_partition_by=partitionings.Arbitrary(), + preserves_partition_by=partitionings.Arbitrary())) + + @frame_base.with_docs_from(pd.core.strings.StringMethods) + @frame_base.args_to_kwargs(pd.core.strings.StringMethods) + @frame_base.populate_defaults(pd.core.strings.StringMethods) + def split(self, **kwargs): + """ + Like other non-deferred methods, dtype must be CategoricalDtype. + One exception is when ``expand`` is ``False``. Because we are not + creating new columns at construction time, dtype can be `str`. + """ + return self._split_helper(rsplit=False, **kwargs) + + @frame_base.with_docs_from(pd.core.strings.StringMethods) + @frame_base.args_to_kwargs(pd.core.strings.StringMethods) + @frame_base.populate_defaults(pd.core.strings.StringMethods) + def rsplit(self, **kwargs): + """ + Like other non-deferred methods, dtype must be CategoricalDtype. + One exception is when ``expand`` is ``False``. Because we are not + creating new columns at construction time, dtype can be `str`. + """ + return self._split_helper(rsplit=True, **kwargs) ELEMENTWISE_STRING_METHODS = [ diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py index 6b48e31128de..0cff789bcccf 100644 --- a/sdks/python/apache_beam/dataframe/frames_test.py +++ b/sdks/python/apache_beam/dataframe/frames_test.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re import unittest import numpy as np @@ -2247,6 +2248,117 @@ def test_sample_with_weights_distribution(self): expected = num_samples * target_prob self.assertTrue(expected / 3 < result < expected * 2, (expected, result)) + def test_split_pandas_examples_no_expand(self): + # if expand=False (default), then no need to cast dtype to be + # CategoricalDtype. + s = pd.Series([ + "this is a regular sentence", + "https://docs.python.org/3/tutorial/index.html", + np.nan + ]) + result = self._evaluate(lambda s: s.str.split(), s) + self.assert_frame_data_equivalent(result, s.str.split()) + + result = self._evaluate(lambda s: s.str.rsplit(), s) + self.assert_frame_data_equivalent(result, s.str.rsplit()) + + result = self._evaluate(lambda s: s.str.split(n=2), s) + self.assert_frame_data_equivalent(result, s.str.split(n=2)) + + result = self._evaluate(lambda s: s.str.rsplit(n=2), s) + self.assert_frame_data_equivalent(result, s.str.rsplit(n=2)) + + result = self._evaluate(lambda s: s.str.split(pat="/"), s) + self.assert_frame_data_equivalent(result, s.str.split(pat="/")) + + def test_split_pandas_examples_expand_not_categorical(self): + # When expand=True, there is exception because series is not categorical + s = pd.Series([ + "this is a regular sentence", + "https://docs.python.org/3/tutorial/index.html", + np.nan + ]) + with self.assertRaisesRegex( + frame_base.WontImplementError, + r"split\(\) of non-categorical type is not supported"): + self._evaluate(lambda s: s.str.split(expand=True), s) + + with self.assertRaisesRegex( + frame_base.WontImplementError, + r"rsplit\(\) of non-categorical type is not supported"): + self._evaluate(lambda s: s.str.rsplit(expand=True), s) + + def test_split_pandas_examples_expand_pat_is_string_literal1(self): + # When expand=True and pattern is treated as a string literal + s = pd.Series([ + "this is a regular sentence", + "https://docs.python.org/3/tutorial/index.html", + np.nan + ]) + s = s.astype( + pd.CategoricalDtype( + categories=[ + 'this is a regular sentence', + 'https://docs.python.org/3/tutorial/index.html' + ])) + result = self._evaluate(lambda s: s.str.split(expand=True), s) + self.assert_frame_data_equivalent(result, s.str.split(expand=True)) + + result = self._evaluate(lambda s: s.str.rsplit("/", n=1, expand=True), s) + self.assert_frame_data_equivalent( + result, s.str.rsplit("/", n=1, expand=True)) + + @unittest.skipIf(PD_VERSION < (1, 4), "regex arg is new in pandas 1.4") + def test_split_pandas_examples_expand_pat_is_string_literal2(self): + # when regex is None (default) regex pat is string literal if len(pat) == 1 + s = pd.Series(['foojpgbar.jpg']).astype('category') + s = s.astype(pd.CategoricalDtype(categories=["foojpgbar.jpg"])) + result = self._evaluate(lambda s: s.str.split(r".", expand=True), s) + self.assert_frame_data_equivalent(result, s.str.split(r".", expand=True)) + + # When regex=False, pat is interpreted as the string itself + result = self._evaluate( + lambda s: s.str.split(r"\.jpg", regex=False, expand=True), s) + self.assert_frame_data_equivalent( + result, s.str.split(r"\.jpg", regex=False, expand=True)) + + @unittest.skipIf(PD_VERSION < (1, 4), "regex arg is new in pandas 1.4") + def test_split_pandas_examples_expand_pat_is_regex(self): + # when regex is None (default) regex pat is compiled if len(pat) != 1 + s = pd.Series(["foo and bar plus baz"]) + s = s.astype(pd.CategoricalDtype(categories=["foo and bar plus baz"])) + result = self._evaluate(lambda s: s.str.split(r"and|plus", expand=True), s) + self.assert_frame_data_equivalent( + result, s.str.split(r"and|plus", expand=True)) + + s = pd.Series(['foojpgbar.jpg']).astype('category') + s = s.astype(pd.CategoricalDtype(categories=["foojpgbar.jpg"])) + result = self._evaluate(lambda s: s.str.split(r"\.jpg", expand=True), s) + self.assert_frame_data_equivalent( + result, s.str.split(r"\.jpg", expand=True)) + + # When regex=True, pat is interpreted as a regex + result = self._evaluate( + lambda s: s.str.split(r"\.jpg", regex=True, expand=True), s) + self.assert_frame_data_equivalent( + result, s.str.split(r"\.jpg", regex=True, expand=True)) + + # A compiled regex can be passed as pat + result = self._evaluate( + lambda s: s.str.split(re.compile(r"\.jpg"), expand=True), s) + self.assert_frame_data_equivalent( + result, s.str.split(re.compile(r"\.jpg"), expand=True)) + + @unittest.skipIf(PD_VERSION < (1, 4), "regex arg is new in pandas 1.4") + def test_split_pat_is_regex(self): + # regex=True, but expand=False + s = pd.Series(['foojpgbar.jpg']).astype('category') + s = s.astype(pd.CategoricalDtype(categories=["foojpgbar.jpg"])) + result = self._evaluate( + lambda s: s.str.split(r"\.jpg", regex=True, expand=False), s) + self.assert_frame_data_equivalent( + result, s.str.split(r"\.jpg", regex=True, expand=False)) + class AllowNonParallelTest(unittest.TestCase): def _use_non_parallel_operation(self): diff --git a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py index 99b64c03d2d0..34777e605719 100644 --- a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py +++ b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py @@ -584,8 +584,6 @@ def test_string_tests(self): f'{module_name}.StringMethods.get_dummies': ['*'], f'{module_name}.str_get_dummies': ['*'], f'{module_name}.StringMethods': ['s.str.split("_")'], - f'{module_name}.StringMethods.rsplit': ['*'], - f'{module_name}.StringMethods.split': ['*'], }, skip={ # count() on Series with a NaN produces mismatched type if we @@ -602,7 +600,32 @@ def test_string_tests(self): ], # output has incorrect formatting in 1.2.x - f'{module_name}.StringMethods.extractall': ['*'] + f'{module_name}.StringMethods.extractall': ['*'], + + # For split and rsplit, if expand=True, then the series + # must be of CategoricalDtype, which pandas doesn't convert to + f'{module_name}.StringMethods.rsplit': [ + 's.str.split(r"\\+|=", expand=True)', # for pandas<1.4 + 's.str.split(expand=True)', + 's.str.rsplit("/", n=1, expand=True)', + 's.str.split(r"and|plus", expand=True)', + 's.str.split(r".", expand=True)', + 's.str.split(r"\\.jpg", expand=True)', + 's.str.split(r"\\.jpg", regex=True, expand=True)', + 's.str.split(re.compile(r"\\.jpg"), expand=True)', + 's.str.split(r"\\.jpg", regex=False, expand=True)' + ], + f'{module_name}.StringMethods.split': [ + 's.str.split(r"\\+|=", expand=True)', # for pandas<1.4 + 's.str.split(expand=True)', + 's.str.rsplit("/", n=1, expand=True)', + 's.str.split(r"and|plus", expand=True)', + 's.str.split(r".", expand=True)', + 's.str.split(r"\\.jpg", expand=True)', + 's.str.split(r"\\.jpg", regex=True, expand=True)', + 's.str.split(re.compile(r"\\.jpg"), expand=True)', + 's.str.split(r"\\.jpg", regex=False, expand=True)' + ] }) self.assertEqual(result.failed, 0)