Skip to content

Commit

Permalink
BUG: Allow value labels to be read with iterator
Browse files Browse the repository at this point in the history
All value labels to be read before the iterator has been used
Fix issue where categorical data was incorrectly reformatted when
write_index was False

closes pandas-dev#16923
  • Loading branch information
bashtage committed Jul 14, 2017
1 parent 6000c5b commit 308d79b
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 20 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.21.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ I/O

- Bug in :func:`read_csv` in which non integer values for the header argument generated an unhelpful / unrelated error message (:issue:`16338`)

- Bug in :func:`read_stata` where value labels could not be read when using an iterator (:issue:`16923`)

Plotting
^^^^^^^^
Expand Down
38 changes: 21 additions & 17 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,7 @@ def __init__(self, path_or_buf, convert_dates=True,
self.path_or_buf = BytesIO(contents)

self._read_header()
self._setup_dtype()

def __enter__(self):
""" enter context manager """
Expand Down Expand Up @@ -1299,6 +1300,23 @@ def _read_old_header(self, first_char):
# necessary data to continue parsing
self.data_location = self.path_or_buf.tell()

def _setup_dtype(self):
# Setup the dtype.
if self._dtype is not None:
return self._dtype

dtype = [] # Convert struct data types to numpy data type
for i, typ in enumerate(self.typlist):
if typ in self.NUMPY_TYPE_MAP:
dtype.append(('s' + str(i), self.byteorder +
self.NUMPY_TYPE_MAP[typ]))
else:
dtype.append(('s' + str(i), 'S' + str(typ)))
dtype = np.dtype(dtype)
self._dtype = dtype

return self._dtype

def _calcsize(self, fmt):
return (type(fmt) is int and fmt or
struct.calcsize(self.byteorder + fmt))
Expand Down Expand Up @@ -1472,24 +1490,12 @@ def read(self, nrows=None, convert_dates=None,
if nrows is None:
nrows = self.nobs

if (self.format_version >= 117) and (self._dtype is None):
if (self.format_version >= 117) and (not self._value_labels_read):
self._can_read_value_labels = True
self._read_strls()

# Setup the dtype.
if self._dtype is None:
dtype = [] # Convert struct data types to numpy data type
for i, typ in enumerate(self.typlist):
if typ in self.NUMPY_TYPE_MAP:
dtype.append(('s' + str(i), self.byteorder +
self.NUMPY_TYPE_MAP[typ]))
else:
dtype.append(('s' + str(i), 'S' + str(typ)))
dtype = np.dtype(dtype)
self._dtype = dtype

# Read data
dtype = self._dtype
dtype = self._setup_dtype()
max_read_len = (self.nobs - self._lines_read) * dtype.itemsize
read_len = nrows * dtype.itemsize
read_len = min(read_len, max_read_len)
Expand Down Expand Up @@ -1958,7 +1964,6 @@ def _prepare_categoricals(self, data):
return data

get_base_missing_value = StataMissingValue.get_base_missing_value
index = data.index
data_formatted = []
for col, col_is_cat in zip(data, is_cat):
if col_is_cat:
Expand All @@ -1981,8 +1986,7 @@ def _prepare_categoricals(self, data):

# Replace missing values with Stata missing value for type
values[values == -1] = get_base_missing_value(dtype)
data_formatted.append((col, values, index))

data_formatted.append((col, values))
else:
data_formatted.append((col, data[col]))
return DataFrame.from_items(data_formatted)
Expand Down
23 changes: 20 additions & 3 deletions pandas/tests/io/test_stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@
from datetime import datetime
from distutils.version import LooseVersion

import pytest
import numpy as np
import pandas as pd
import pandas.util.testing as tm
import pytest
from pandas import compat
from pandas._libs.tslib import NaT
from pandas.compat import iterkeys
from pandas.core.dtypes.common import is_categorical_dtype
from pandas.core.frame import DataFrame, Series
from pandas.io.parsers import read_csv
from pandas.io.stata import (read_stata, StataReader, InvalidColumnName,
PossiblePrecisionLoss, StataMissingValue)
from pandas._libs.tslib import NaT
from pandas.core.dtypes.common import is_categorical_dtype


class TestStata(object):
Expand Down Expand Up @@ -1297,3 +1297,20 @@ def test_pickle_path_localpath(self):
reader = lambda x: read_stata(x).set_index('index')
result = tm.round_trip_localpath(df.to_stata, reader)
tm.assert_frame_equal(df, result)

def test_value_labels_iterator(self):
# GH 16923
d = {'A': ['B', 'E', 'C', 'A', 'E']}
df = pd.DataFrame(data=d)
df['A'] = df['A'].astype('category')
with tm.ensure_clean() as path:
df.to_stata(path, write_index=False)
dta_iter = pd.read_stata('test.dta', iterator=True)
value_labels = dta_iter.value_labels()
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}

with tm.ensure_clean() as path:
df.to_stata(path)
dta_iter = pd.read_stata('test.dta', iterator=True)
value_labels = dta_iter.value_labels()
assert value_labels == {'A': {0: 'A', 1: 'B', 2: 'C', 3: 'E'}}

0 comments on commit 308d79b

Please sign in to comment.