diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py index e4eb48eb3527..bb6f47227928 100644 --- a/sdks/python/apache_beam/dataframe/frames.py +++ b/sdks/python/apache_beam/dataframe/frames.py @@ -869,6 +869,70 @@ def mask(self, cond, **kwargs): """mask is not parallelizable when ``errors="ignore"`` is specified.""" return self.where(~cond, **kwargs) + @frame_base.with_docs_from(pd.DataFrame) + @frame_base.args_to_kwargs(pd.DataFrame) + @frame_base.populate_defaults(pd.DataFrame) + def xs(self, key, axis, level, **kwargs): + """Note that ``xs(axis='index')`` will raise a ``KeyError`` at execution + time if the key does not exist in the index.""" + + if axis in ('columns', 1): + # Special case for axis=columns. This is a simple project that raises a + # KeyError at construction time for missing columns. + return frame_base.DeferredFrame.wrap( + expressions.ComputedExpression( + 'xs', + lambda df: df.xs(key, axis=axis, **kwargs), [self._expr], + requires_partition_by=partitionings.Arbitrary(), + preserves_partition_by=partitionings.Arbitrary())) + elif axis not in ('index', 0): + # Make sure that user's axis is valid + raise ValueError( + "axis must be one of ('index', 0, 'columns', 1). " + f"got {axis!r}.") + + if not isinstance(key, tuple): + key = (key, ) + + key_size = len(key) + key_series = pd.Series([key], pd.MultiIndex.from_tuples([key])) + key_expr = expressions.ConstantExpression( + key_series, proxy=key_series.iloc[:0]) + + if level is None: + reindexed = self + else: + if not isinstance(level, list): + level = [level] + + # If user specifed levels, reindex so those levels are at the beginning. + # Keep the others and preserve their order. + level = [ + l if isinstance(l, int) else list(self.index.names).index(l) + for l in level + ] + + reindexed = self.reorder_levels( + level + [i for i in range(self.index.nlevels) if i not in level]) + + def xs_partitioned(frame, key): + if not len(key): + # key is not in this partition, return empty dataframe + return frame.iloc[:0].droplevel(list(range(key_size))) + + # key should be in this partition, call xs. Will raise KeyError if not + # present. + return frame.xs(key.item()) + + return frame_base.DeferredFrame.wrap( + expressions.ComputedExpression( + 'xs', + xs_partitioned, + [reindexed._expr, key_expr], + requires_partition_by=partitionings.Index(list(range(key_size))), + # Drops index levels, so partitioning is not preserved + preserves_partition_by=partitionings.Singleton())) + @property def dtype(self): return self._expr.proxy().dtype diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py index 95c69ffb4c97..028c91a2b6f3 100644 --- a/sdks/python/apache_beam/dataframe/frames_test.py +++ b/sdks/python/apache_beam/dataframe/frames_test.py @@ -235,6 +235,25 @@ def test_get_column(self): self._run_test(lambda df: df.get('Animal'), df) self._run_test(lambda df: df.get('FOO', df.Animal), df) + def test_series_xs(self): + # pandas doctests only verify DataFrame.xs, here we verify Series.xs as well + d = { + 'num_legs': [4, 4, 2, 2], + 'num_wings': [0, 0, 2, 2], + 'class': ['mammal', 'mammal', 'mammal', 'bird'], + 'animal': ['cat', 'dog', 'bat', 'penguin'], + 'locomotion': ['walks', 'walks', 'flies', 'walks'] + } + df = pd.DataFrame(data=d) + df = df.set_index(['class', 'animal', 'locomotion']) + + self._run_test(lambda df: df.num_legs.xs('mammal'), df) + self._run_test(lambda df: df.num_legs.xs(('mammal', 'dog')), df) + self._run_test(lambda df: df.num_legs.xs('cat', level=1), df) + self._run_test( + lambda df: df.num_legs.xs(('bird', 'walks'), level=[0, 'locomotion']), + df) + def test_set_column(self): def new_column(df): df['NewCol'] = df['Speed'] diff --git a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py index 44c20241a14c..e81e463dfdda 100644 --- a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py +++ b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py @@ -102,7 +102,6 @@ def test_ndframe_tests(self): 'pandas.core.generic.NDFrame.set_flags': ['*'], 'pandas.core.generic.NDFrame.squeeze': ['*'], 'pandas.core.generic.NDFrame.truncate': ['*'], - 'pandas.core.generic.NDFrame.xs': ['*'], }, skip={ # Internal test