diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 0425cd7bba106..046bcac674a28 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -1395,7 +1395,7 @@ def pop(self, item): def _series(self): return self._data.get_series_dict() - def xs(self, key, axis=0, copy=True): + def xs(self, key, axis=0, level=None, copy=True): """ Returns a cross-section (row or column) from the DataFrame as a Series object. Defaults to returning a row (axis 0) @@ -1413,6 +1413,15 @@ def xs(self, key, axis=0, copy=True): ------- xs : Series """ + labels = self._get_axis(axis) + if level is not None: + indexer = [slice(None, None)] * 2 + indexer[axis] = labels.get_loc_level(key, level=level) + result = self.ix[tuple(indexer)] + new_ax = result._get_axis(axis).droplevel(level) + setattr(result, result._get_axis_name(axis), new_ax) + return result + if axis == 1: data = self[key] if copy: diff --git a/pandas/core/index.py b/pandas/core/index.py index 26b93a49cc3e8..9d4f496eb2d0a 100644 --- a/pandas/core/index.py +++ b/pandas/core/index.py @@ -1608,22 +1608,71 @@ def get_loc(self, key): if len(key) == self.nlevels: return self._engine.get_loc(key) else: + # partial selection result = slice(*self.slice_locs(key, key)) if result.start == result.stop: raise KeyError(key) return result else: - level = self.levels[0] - labels = self.labels[0] - loc = level.get_loc(key) + return self._get_level_indexer(key, level=0) - if self.lexsort_depth == 0: - return labels == loc + def get_loc_level(self, key, level=0): + """ + Get integer location slice for requested label or tuple + + Parameters + ---------- + key : label or tuple + + Returns + ------- + loc : int or slice object + """ + if isinstance(key, tuple) and level == 0: + if not any(isinstance(k, slice) for k in key): + if len(key) == self.nlevels: + return self._engine.get_loc(key) + else: + # partial selection + result = slice(*self.slice_locs(key, key)) + if result.start == result.stop: + raise KeyError(key) + return result else: - # sorted, so can return slice object -> view - i = labels.searchsorted(loc, side='left') - j = labels.searchsorted(loc, side='right') - return slice(i, j) + indexer = None + for i, k in enumerate(key): + if k is None: + continue + + if isinstance(k, slice): + if k == slice(None, None): + continue + else: + k_index = np.empty(len(self), dtype=bool) + k_index[k] = True + else: + k_index = self._get_level_indexer(k, level=i) + + if indexer is None: + indexer = k_index + else: + indexer &= k_index + return indexer + else: + return self._get_level_indexer(key, level=level) + + def _get_level_indexer(self, key, level=0): + level_index = self.levels[level] + loc = level_index.get_loc(key) + labels = self.labels[level] + + if level > 0 or self.lexsort_depth == 0: + return labels == loc + else: + # sorted, so can return slice object -> view + i = labels.searchsorted(loc, side='left') + j = labels.searchsorted(loc, side='right') + return slice(i, j) def truncate(self, before=None, after=None): """ diff --git a/pandas/core/series.py b/pandas/core/series.py index 8fe51364f867f..2bf6dfcf58717 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -308,12 +308,8 @@ def _get_with(self, key): indexer = self.ix._convert_to_indexer(key, axis=0) return self._get_values(indexer) else: - # mpl hackaround if isinstance(key, tuple): - try: - return self._get_values(key) - except Exception: - pass + return self._get_values_tuple(key) if not isinstance(key, (list, np.ndarray)): key = list(key) @@ -338,6 +334,33 @@ def _get_with(self, key): return self._get_values(key) raise + def _get_values_tuple(self, key): + # mpl hackaround + if any(k is None for k in key): + return self._get_values(key) + + if not isinstance(self.index, MultiIndex): + raise ValueError('Can only tuple-index with a MultiIndex') + + indexer = self.index.get_loc_level(key) + result = self._get_values(indexer) + + # kludgearound + new_index = result.index + for i, k in reversed(list(enumerate(key))): + if k != slice(None, None): + new_index = new_index.droplevel(i) + result.index = new_index + + return result + + def _get_values(self, indexer): + try: + return Series(self.values[indexer], index=self.index[indexer], + name=self.name) + except Exception: + return self.values[indexer] + def __setitem__(self, key, value): values = self.values try: @@ -397,13 +420,6 @@ def _set_labels(self, key, value): % str(key[mask])) self._set_values(indexer, value) - def _get_values(self, indexer): - try: - return Series(self.values[indexer], index=self.index[indexer], - name=self.name) - except Exception: - return self.values[indexer] - def _set_values(self, key, value): self.values[key] = value diff --git a/pandas/tests/test_multilevel.py b/pandas/tests/test_multilevel.py index 7d7fbaf73aa47..f3b6beab4ea78 100644 --- a/pandas/tests/test_multilevel.py +++ b/pandas/tests/test_multilevel.py @@ -203,6 +203,19 @@ def test_xs_partial(self): assert_frame_equal(result, expected) assert_frame_equal(result, result2) + def test_xs_level(self): + result = self.frame.xs('two', level=1) + expected = self.frame[self.frame.index.get_level_values(1) == 'two'] + expected.index = expected.index.droplevel(1) + + assert_frame_equal(result, expected) + + def test_xs_level_series(self): + s = self.frame['A'] + result = s[:, 'two'] + expected = self.frame.xs('two', level=1)['A'] + assert_series_equal(result, expected) + def test_fancy_2d(self): result = self.frame.ix['foo', 'B'] expected = self.frame.xs('foo')['B']