Skip to content

Commit

Permalink
[BEAM-9547] Add support for xs on DataFrame and Series (#15078)
Browse files Browse the repository at this point in the history
* Add support for xs on DataFrame and Series

* add comments
  • Loading branch information
TheNeuralBit authored Jun 25, 2021
1 parent dd01eb9 commit 225fcd0
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
64 changes: 64 additions & 0 deletions sdks/python/apache_beam/dataframe/frames.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,70 @@ def mask(self, cond, **kwargs):
"""mask is not parallelizable when ``errors="ignore"`` is specified."""
return self.where(~cond, **kwargs)

@frame_base.with_docs_from(pd.DataFrame)
@frame_base.args_to_kwargs(pd.DataFrame)
@frame_base.populate_defaults(pd.DataFrame)
def xs(self, key, axis, level, **kwargs):
  """Note that ``xs(axis='index')`` will raise a ``KeyError`` at execution
  time if the key does not exist in the index."""
  # Overview:
  #   * axis='columns' is an elementwise projection: every partition runs the
  #     same pandas ``xs`` call, so no particular partitioning is required.
  #   * axis='index' partitions the input by the leading ``len(key)`` index
  #     levels (after moving any user-specified ``level``s to the front) and
  #     joins it with a single-row constant Series indexed by ``key``. Only
  #     the partition that receives that row performs the lookup; all others
  #     return an empty result.
  #
  # Raises ValueError at construction time when ``axis`` is not one of
  # 'index', 0, 'columns', 1.

  if axis in ('columns', 1):
    # Special case for axis=columns. This is a simple project that raises a
    # KeyError at construction time for missing columns.
    # ``level`` is a named parameter and therefore never part of ``kwargs``;
    # forward it explicitly so that selections on MultiIndex columns honor it.
    return frame_base.DeferredFrame.wrap(
        expressions.ComputedExpression(
            'xs',
            lambda df: df.xs(key, axis=axis, level=level, **kwargs),
            [self._expr],
            requires_partition_by=partitionings.Arbitrary(),
            preserves_partition_by=partitionings.Arbitrary()))
  elif axis not in ('index', 0):
    # Make sure that user's axis is valid
    raise ValueError(
        "axis must be one of ('index', 0, 'columns', 1). "
        f"got {axis!r}.")

  # Normalize a scalar key to a 1-tuple so it can be treated uniformly as a
  # (possibly partial) MultiIndex key.
  if not isinstance(key, tuple):
    key = (key, )

  key_size = len(key)
  # A single-row Series whose index is the key and whose value is the key
  # tuple itself. Its proxy is the empty slice of that same Series.
  key_series = pd.Series([key], pd.MultiIndex.from_tuples([key]))
  key_expr = expressions.ConstantExpression(
      key_series, proxy=key_series.iloc[:0])

  if level is None:
    reindexed = self
  else:
    if not isinstance(level, list):
      level = [level]

    # If user specified levels, reindex so those levels are at the beginning.
    # Keep the others and preserve their order.
    # Level names are translated to their integer positions first.
    level = [
        lvl if isinstance(lvl, int) else list(self.index.names).index(lvl)
        for lvl in level
    ]

    reindexed = self.reorder_levels(
        level + [i for i in range(self.index.nlevels) if i not in level])

  def xs_partitioned(frame, key_partition):
    if not len(key_partition):
      # key is not in this partition, return empty dataframe
      empty = frame.iloc[:0]
      if key_size < frame.index.nlevels:
        return empty.droplevel(list(range(key_size)))
      # The key spans every index level. pandas refuses to drop all levels
      # of an index (ValueError), so hand back the empty slice unchanged.
      return empty

    # key should be in this partition, call xs. Will raise KeyError if not
    # present.
    return frame.xs(key_partition.item())

  return frame_base.DeferredFrame.wrap(
      expressions.ComputedExpression(
          'xs',
          xs_partitioned,
          [reindexed._expr, key_expr],
          requires_partition_by=partitionings.Index(list(range(key_size))),
          # Drops index levels, so partitioning is not preserved
          preserves_partition_by=partitionings.Singleton()))

@property
def dtype(self):
  """The ``dtype`` of the proxy object returned by ``self._expr.proxy()``."""
  proxy = self._expr.proxy()
  return proxy.dtype
Expand Down
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/dataframe/frames_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,25 @@ def test_get_column(self):
self._run_test(lambda df: df.get('Animal'), df)
self._run_test(lambda df: df.get('FOO', df.Animal), df)

def test_series_xs(self):
  """Exercise ``Series.xs`` on a Series with a three-level index."""
  # The pandas doctests cover DataFrame.xs only, so Series.xs is checked here.
  animals = pd.DataFrame(
      data={
          'num_legs': [4, 4, 2, 2],
          'num_wings': [0, 0, 2, 2],
          'class': ['mammal', 'mammal', 'mammal', 'bird'],
          'animal': ['cat', 'dog', 'bat', 'penguin'],
          'locomotion': ['walks', 'walks', 'flies', 'walks'],
      }).set_index(['class', 'animal', 'locomotion'])

  cases = (
      # Scalar key on the outermost level.
      lambda df: df.num_legs.xs('mammal'),
      # Partial tuple key on the two outermost levels.
      lambda df: df.num_legs.xs(('mammal', 'dog')),
      # Single level given by position.
      lambda df: df.num_legs.xs('cat', level=1),
      # Several levels, mixing a position and a name.
      lambda df: df.num_legs.xs(('bird', 'walks'), level=[0, 'locomotion']),
  )
  for case in cases:
    self._run_test(case, animals)

def test_set_column(self):
def new_column(df):
df['NewCol'] = df['Speed']
Expand Down
1 change: 0 additions & 1 deletion sdks/python/apache_beam/dataframe/pandas_doctests_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,6 @@ def test_ndframe_tests(self):
'pandas.core.generic.NDFrame.set_flags': ['*'],
'pandas.core.generic.NDFrame.squeeze': ['*'],
'pandas.core.generic.NDFrame.truncate': ['*'],
'pandas.core.generic.NDFrame.xs': ['*'],
},
skip={
# Internal test
Expand Down

0 comments on commit 225fcd0

Please sign in to comment.