From 225fcd0d7be8970f010102c6f9c34a7deb8e0c5c Mon Sep 17 00:00:00 2001
From: Brian Hulette <bhulette@google.com>
Date: Fri, 25 Jun 2021 15:19:44 -0700
Subject: [PATCH] [BEAM-9547] Add support for xs on DataFrame and Series
 (#15078)

* Add support for xs on DataFrame and Series

* add comments
---
 sdks/python/apache_beam/dataframe/frames.py   | 64 +++++++++++++++++++
 .../apache_beam/dataframe/frames_test.py      | 19 ++++++
 .../dataframe/pandas_doctests_test.py         |  1 -
 3 files changed, 83 insertions(+), 1 deletion(-)

diff --git a/sdks/python/apache_beam/dataframe/frames.py b/sdks/python/apache_beam/dataframe/frames.py
index e4eb48eb3527..bb6f47227928 100644
--- a/sdks/python/apache_beam/dataframe/frames.py
+++ b/sdks/python/apache_beam/dataframe/frames.py
@@ -869,6 +869,70 @@ def mask(self, cond, **kwargs):
     """mask is not parallelizable when ``errors="ignore"`` is specified."""
     return self.where(~cond, **kwargs)
 
+  @frame_base.with_docs_from(pd.DataFrame)
+  @frame_base.args_to_kwargs(pd.DataFrame)
+  @frame_base.populate_defaults(pd.DataFrame)
+  def xs(self, key, axis, level, **kwargs):
+    """Note that ``xs(axis='index')`` will raise a ``KeyError`` at execution
+    time if the key does not exist in the index."""
+
+    if axis in ('columns', 1):
+      # Special case for axis=columns: a simple projection (honoring level)
+      # that raises a KeyError at construction time for missing columns.
+      return frame_base.DeferredFrame.wrap(
+          expressions.ComputedExpression(
+              'xs', lambda df: df.xs(key, axis=axis, level=level, **kwargs),
+              [self._expr],
+              requires_partition_by=partitionings.Arbitrary(),
+              preserves_partition_by=partitionings.Arbitrary()))
+    elif axis not in ('index', 0):
+      # Make sure that user's axis is valid
+      raise ValueError(
+          "axis must be one of ('index', 0, 'columns', 1). "
+          f"got {axis!r}.")
+
+    if not isinstance(key, tuple):
+      key = (key, )
+
+    key_size = len(key)
+    key_series = pd.Series([key], pd.MultiIndex.from_tuples([key]))
+    key_expr = expressions.ConstantExpression(
+        key_series, proxy=key_series.iloc[:0])
+
+    if level is None:
+      reindexed = self
+    else:
+      if not isinstance(level, list):
+        level = [level]
+
+      # If user specified levels, reindex so those levels are at the beginning.
+      # Keep the others and preserve their order.
+      level = [
+          l if isinstance(l, int) else list(self.index.names).index(l)
+          for l in level
+      ]
+
+      reindexed = self.reorder_levels(
+          level + [i for i in range(self.index.nlevels) if i not in level])
+
+    def xs_partitioned(frame, key):
+      if not len(key):
+        # key is not in this partition, return empty dataframe
+        return frame.iloc[:0].droplevel(list(range(key_size)))
+
+      # key should be in this partition, call xs. Will raise KeyError if not
+      # present.
+      return frame.xs(key.item())
+
+    return frame_base.DeferredFrame.wrap(
+        expressions.ComputedExpression(
+            'xs',
+            xs_partitioned,
+            [reindexed._expr, key_expr],
+            requires_partition_by=partitionings.Index(list(range(key_size))),
+            # Drops index levels, so partitioning is not preserved
+            preserves_partition_by=partitionings.Singleton()))
+
   @property
   def dtype(self):
     return self._expr.proxy().dtype
diff --git a/sdks/python/apache_beam/dataframe/frames_test.py b/sdks/python/apache_beam/dataframe/frames_test.py
index 95c69ffb4c97..028c91a2b6f3 100644
--- a/sdks/python/apache_beam/dataframe/frames_test.py
+++ b/sdks/python/apache_beam/dataframe/frames_test.py
@@ -235,6 +235,25 @@ def test_get_column(self):
     self._run_test(lambda df: df.get('Animal'), df)
     self._run_test(lambda df: df.get('FOO', df.Animal), df)
 
+  def test_series_xs(self):
+    # pandas doctests only cover DataFrame.xs, so Series.xs is verified here
+    # against a three-level MultiIndex.
+    animals = pd.DataFrame({
+        'num_legs': [4, 4, 2, 2],
+        'num_wings': [0, 0, 2, 2],
+        'class': ['mammal', 'mammal', 'mammal', 'bird'],
+        'animal': ['cat', 'dog', 'bat', 'penguin'],
+        'locomotion': ['walks', 'walks', 'flies', 'walks']
+    }).set_index(['class', 'animal', 'locomotion'])
+
+    self._run_test(lambda df: df.num_legs.xs('mammal'), animals)
+    self._run_test(lambda df: df.num_legs.xs(('mammal', 'dog')), animals)
+    self._run_test(lambda df: df.num_legs.xs('cat', level=1), animals)
+    self._run_test(
+        lambda df: df.num_legs.xs(('bird', 'walks'), level=[0, 'locomotion']),
+        animals)
+
+
   def test_set_column(self):
     def new_column(df):
       df['NewCol'] = df['Speed']
diff --git a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py
index 44c20241a14c..e81e463dfdda 100644
--- a/sdks/python/apache_beam/dataframe/pandas_doctests_test.py
+++ b/sdks/python/apache_beam/dataframe/pandas_doctests_test.py
@@ -102,7 +102,6 @@ def test_ndframe_tests(self):
             'pandas.core.generic.NDFrame.set_flags': ['*'],
             'pandas.core.generic.NDFrame.squeeze': ['*'],
             'pandas.core.generic.NDFrame.truncate': ['*'],
-            'pandas.core.generic.NDFrame.xs': ['*'],
         },
         skip={
             # Internal test