From 6858d0f6caa60c98acc4b6c3eaa6cd0309aedca6 Mon Sep 17 00:00:00 2001 From: Kevin Sheppard Date: Fri, 14 Jul 2017 22:20:28 +0100 Subject: [PATCH] BUG: Allow value labels to be read with iterator (#16926) All value labels to be read before the iterator has been used Fix issue where categorical data was incorrectly reformatted when write_index was False closes #16923 --- doc/source/whatsnew/v0.21.0.txt | 1 + pandas/io/stata.py | 36 ++++++++++++++++++--------------- pandas/tests/io/test_stata.py | 18 ++++++++++++++--- 3 files changed, 36 insertions(+), 19 deletions(-) diff --git a/doc/source/whatsnew/v0.21.0.txt b/doc/source/whatsnew/v0.21.0.txt index 2716d9b09eaa9..bd19d71182762 100644 --- a/doc/source/whatsnew/v0.21.0.txt +++ b/doc/source/whatsnew/v0.21.0.txt @@ -162,6 +162,7 @@ I/O - Bug in :func:`read_csv` in which non integer values for the header argument generated an unhelpful / unrelated error message (:issue:`16338`) +- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`) Plotting ^^^^^^^^ diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 107dccfc8175c..30991d8a24c63 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -997,6 +997,7 @@ def __init__(self, path_or_buf, convert_dates=True, self.path_or_buf = BytesIO(contents) self._read_header() + self._setup_dtype() def __enter__(self): """ enter context manager """ @@ -1299,6 +1300,23 @@ def _read_old_header(self, first_char): # necessary data to continue parsing self.data_location = self.path_or_buf.tell() + def _setup_dtype(self): + """Map between numpy and state dtypes""" + if self._dtype is not None: + return self._dtype + + dtype = [] # Convert struct data types to numpy data type + for i, typ in enumerate(self.typlist): + if typ in self.NUMPY_TYPE_MAP: + dtype.append(('s' + str(i), self.byteorder + + self.NUMPY_TYPE_MAP[typ])) + else: + dtype.append(('s' + str(i), 'S' + str(typ))) + dtype = np.dtype(dtype) + self._dtype = dtype + + return self._dtype + def _calcsize(self, fmt): return (type(fmt) is int and fmt or struct.calcsize(self.byteorder + fmt)) @@ -1472,22 +1490,10 @@ def read(self, nrows=None, convert_dates=None, if nrows is None: nrows = self.nobs - if (self.format_version >= 117) and (self._dtype is None): + if (self.format_version >= 117) and (not self._value_labels_read): self._can_read_value_labels = True self._read_strls() - # Setup the dtype. - if self._dtype is None: - dtype = [] # Convert struct data types to numpy data type - for i, typ in enumerate(self.typlist): - if typ in self.NUMPY_TYPE_MAP: - dtype.append(('s' + str(i), self.byteorder + - self.NUMPY_TYPE_MAP[typ])) - else: - dtype.append(('s' + str(i), 'S' + str(typ))) - dtype = np.dtype(dtype) - self._dtype = dtype - # Read data dtype = self._dtype max_read_len = (self.nobs - self._lines_read) * dtype.itemsize @@ -1958,7 +1964,6 @@ def _prepare_categoricals(self, data): return data get_base_missing_value = StataMissingValue.get_base_missing_value - index = data.index data_formatted = [] for col, col_is_cat in zip(data, is_cat): if col_is_cat: @@ -1981,8 +1986,7 @@ def _prepare_categoricals(self, data): # Replace missing values with Stata missing value for type values[values == -1] = get_base_missing_value(dtype) - data_formatted.append((col, values, index)) - + data_formatted.append((col, values)) else: data_formatted.append((col, data[col])) return DataFrame.from_items(data_formatted) diff --git a/pandas/tests/io/test_stata.py b/pandas/tests/io/test_stata.py index b9c6736563160..a414928d318c4 100644 --- a/pandas/tests/io/test_stata.py +++ b/pandas/tests/io/test_stata.py @@ -9,18 +9,18 @@ from datetime import datetime from distutils.version import LooseVersion -import pytest import numpy as np import pandas as pd import pandas.util.testing as tm +import pytest from pandas import compat +from pandas._libs.tslib import NaT from pandas.compat import iterkeys +from pandas.core.dtypes.common import is_categorical_dtype from pandas.core.frame import DataFrame, Series from pandas.io.parsers import read_csv from pandas.io.stata import (read_stata, StataReader, InvalidColumnName, PossiblePrecisionLoss, StataMissingValue) -from pandas._libs.tslib import NaT -from pandas.core.dtypes.common import is_categorical_dtype class TestStata(object): @@ -1297,3 +1297,15 @@ def test_pickle_path_localpath(self): reader = lambda x: read_stata(x).set_index('index') result = tm.round_trip_localpath(df.to_stata, reader) tm.assert_frame_equal(df, result) + + @pytest.mark.parametrize('write_index', [True, False]) + def test_value_labels_iterator(self, write_index): + # GH 16923 + d = {'A': ['B', 'E', 'C', 'A', 'E']} + df = pd.DataFrame(data=d) + df['A'] = df['A'].astype('category') + with tm.ensure_clean() as path: + df.to_stata(path, write_index=write_index) + dta_iter = pd.read_stata(path, iterator=True) + value_labels = dta_iter.value_labels() + assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}