Skip to content

Commit

Permalink
ENH: hack toward #629
Browse files: browse the repository at this point in the history
  • Loading branch information
wesm committed Jan 15, 2012
1 parent ac26c84 commit 41ec919
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 22 deletions.
11 changes: 10 additions & 1 deletion pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ def pop(self, item):
def _series(self):
return self._data.get_series_dict()

def xs(self, key, axis=0, copy=True):
def xs(self, key, axis=0, level=None, copy=True):
"""
Returns a cross-section (row or column) from the DataFrame as a Series
object. Defaults to returning a row (axis 0)
Expand All @@ -1413,6 +1413,15 @@ def xs(self, key, axis=0, copy=True):
-------
xs : Series
"""
labels = self._get_axis(axis)
if level is not None:
indexer = [slice(None, None)] * 2
indexer[axis] = labels.get_loc_level(key, level=level)
result = self.ix[tuple(indexer)]
new_ax = result._get_axis(axis).droplevel(level)
setattr(result, result._get_axis_name(axis), new_ax)
return result

if axis == 1:
data = self[key]
if copy:
Expand Down
67 changes: 58 additions & 9 deletions pandas/core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -1608,22 +1608,71 @@ def get_loc(self, key):
if len(key) == self.nlevels:
return self._engine.get_loc(key)
else:
# partial selection
result = slice(*self.slice_locs(key, key))
if result.start == result.stop:
raise KeyError(key)
return result
else:
level = self.levels[0]
labels = self.labels[0]
loc = level.get_loc(key)
return self._get_level_indexer(key, level=0)

if self.lexsort_depth == 0:
return labels == loc
def get_loc_level(self, key, level=0):
    """
    Get the location(s) selected by a label, or by a tuple of labels and
    slices, on the requested level(s).

    Parameters
    ----------
    key : label or tuple
        A tuple is interpreted entry-by-entry (entry ``i`` constrains
        level ``i``) only when ``level == 0``. Each entry may be a label,
        ``None`` or ``slice(None)`` (no constraint), or another slice,
        which is applied positionally to the rows of the index.
        Any other key is looked up as a single label on ``level``.
    level : int, default 0

    Returns
    -------
    loc : slice, boolean ndarray, or the result of ``self._engine.get_loc``
        * full tuple of labels: whatever ``self._engine.get_loc`` returns
        * partial tuple of labels: a slice
        * tuple containing slices: the single constraint's indexer as-is,
          a boolean mask when several constraints are combined, or
          ``slice(None, None)`` when nothing is constrained
        * anything else: the result of ``self._get_level_indexer``

    Raises
    ------
    KeyError
        If a partial tuple of labels selects no rows.
    """
    if not (isinstance(key, tuple) and level == 0):
        return self._get_level_indexer(key, level=level)

    if not any(isinstance(k, slice) for k in key):
        if len(key) == self.nlevels:
            return self._engine.get_loc(key)

        # partial selection
        result = slice(*self.slice_locs(key, key))
        if result.start == result.stop:
            raise KeyError(key)
        return result

    def _as_mask(loc):
        # Slices are expanded into a zero-initialised boolean mask so they
        # can be AND-ed with other constraints. np.empty would leave the
        # unselected positions holding arbitrary values.
        if isinstance(loc, slice):
            mask = np.zeros(len(self), dtype=bool)
            mask[loc] = True
            return mask
        return loc

    indexer = None
    for i, k in enumerate(key):
        if k is None:
            continue

        if isinstance(k, slice):
            if k == slice(None, None):
                continue
            k_index = _as_mask(k)
        else:
            # may be a slice or a boolean array depending on the level
            k_index = self._get_level_indexer(k, level=i)

        if indexer is None:
            indexer = k_index
        else:
            # Build a new array instead of ``&=`` so that a slice on
            # either side works and arrays handed back by
            # _get_level_indexer are never modified in place.
            indexer = _as_mask(indexer) & _as_mask(k_index)

    if indexer is None:
        # every entry was None / slice(None): select everything
        return slice(None, None)
    return indexer

def _get_level_indexer(self, key, level=0):
    """
    Locate the rows whose entry on a single level equals ``key``.

    Parameters
    ----------
    key : label
        Looked up in ``self.levels[level]`` with ``get_loc``.
    level : int, default 0

    Returns
    -------
    indexer : slice or the result of comparing the level's labels to the
        located position
    """
    position = self.levels[level].get_loc(key)
    level_labels = self.labels[level]

    # A slice is only produced for level 0 (or below) of an index whose
    # lexsort depth is non-zero; every other case compares element-wise.
    if level <= 0 and self.lexsort_depth != 0:
        # sorted, so can return slice object -> view
        start = level_labels.searchsorted(position, side='left')
        stop = level_labels.searchsorted(position, side='right')
        return slice(start, stop)

    return level_labels == position

def truncate(self, before=None, after=None):
"""
Expand Down
40 changes: 28 additions & 12 deletions pandas/core/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,12 +308,8 @@ def _get_with(self, key):
indexer = self.ix._convert_to_indexer(key, axis=0)
return self._get_values(indexer)
else:
# mpl hackaround
if isinstance(key, tuple):
try:
return self._get_values(key)
except Exception:
pass
return self._get_values_tuple(key)

if not isinstance(key, (list, np.ndarray)):
key = list(key)
Expand All @@ -338,6 +334,33 @@ def _get_with(self, key):
return self._get_values(key)
raise

def _get_values_tuple(self, key):
    """
    Resolve a tuple key against this object's index.

    A key containing ``None`` is passed straight to ``_get_values``.
    Otherwise the index must be a MultiIndex: the key is located with
    ``get_loc_level`` and every level whose key entry is not
    ``slice(None, None)`` is dropped from the result's index.

    Raises
    ------
    ValueError
        If the key has no ``None`` entry and the index is not a MultiIndex.
    """
    # mpl hackaround
    for part in key:
        if part is None:
            return self._get_values(key)

    if isinstance(self.index, MultiIndex):
        locs = self.index.get_loc_level(key)
        selected = self._get_values(locs)

        # kludgearound: walk the key back to front so that dropping one
        # level does not shift the positions still to be visited
        trimmed = selected.index
        for pos, part in list(enumerate(key))[::-1]:
            if part != slice(None, None):
                trimmed = trimmed.droplevel(pos)
        selected.index = trimmed

        return selected

    raise ValueError('Can only tuple-index with a MultiIndex')

def _get_values(self, indexer):
    """
    Select ``indexer`` from the values and the index and wrap the result
    in a Series carrying this object's name.

    If anything in that attempt raises, the bare ``self.values[indexer]``
    is returned instead.
    """
    try:
        picked = self.values[indexer]
        picked_index = self.index[indexer]
        return Series(picked, index=picked_index, name=self.name)
    except Exception:
        # deliberate best-effort fallback to the raw selection
        return self.values[indexer]

def __setitem__(self, key, value):
values = self.values
try:
Expand Down Expand Up @@ -397,13 +420,6 @@ def _set_labels(self, key, value):
% str(key[mask]))
self._set_values(indexer, value)

def _get_values(self, indexer):
    """
    Build a Series named ``self.name`` from the values and index selected
    by ``indexer``; on any exception return ``self.values[indexer]``.
    """
    try:
        subset = self.values[indexer]
        subset_index = self.index[indexer]
        return Series(subset, index=subset_index, name=self.name)
    except Exception:
        # deliberate best-effort fallback to the raw selection
        return self.values[indexer]

def _set_values(self, key, value):
    """Assign ``value`` at ``key`` directly on ``self.values``."""
    target = self.values
    target[key] = value

Expand Down
13 changes: 13 additions & 0 deletions pandas/tests/test_multilevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,19 @@ def test_xs_partial(self):
assert_frame_equal(result, expected)
assert_frame_equal(result, result2)

def test_xs_level(self):
    """xs with level=1 matches masking on that level and dropping it."""
    frame = self.frame
    result = frame.xs('two', level=1)

    mask = frame.index.get_level_values(1) == 'two'
    expected = frame[mask]
    expected.index = expected.index.droplevel(1)

    assert_frame_equal(result, expected)

def test_xs_level_series(self):
    """Tuple-indexing a column with ``[:, 'two']`` matches the same
    column taken from ``frame.xs('two', level=1)``."""
    column = self.frame['A']
    result = column[:, 'two']
    cross_section = self.frame.xs('two', level=1)
    expected = cross_section['A']
    assert_series_equal(result, expected)

def test_fancy_2d(self):
result = self.frame.ix['foo', 'B']
expected = self.frame.xs('foo')['B']
Expand Down

0 comments on commit 41ec919

Please sign in to comment.