diff --git a/pandas/core/frame.py b/pandas/core/frame.py
index c7db59949575b0..afe4f5368469b7 100644
--- a/pandas/core/frame.py
+++ b/pandas/core/frame.py
@@ -9,8 +9,8 @@ labeling information
 """
 import collections
-import datetime
 from collections import abc
+import datetime
 from io import StringIO
 import itertools
 import sys
@@ -1854,14 +1854,14 @@ def _from_arrays(cls, arrays, columns, index, dtype=None) -> "DataFrame":
     def to_stata(
         self,
         path: FilePathOrBuffer,
-        convert_dates: Optional[Dict[Hashable, str]] = None,
+        convert_dates: Optional[Dict[Label, str]] = None,
         write_index: bool = True,
         byteorder: Optional[str] = None,
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
-        variable_labels: Optional[Dict[Hashable, str]] = None,
+        variable_labels: Optional[Dict[Label, str]] = None,
         version: Optional[int] = 114,
-        convert_strl: Optional[Sequence[Hashable]] = None,
+        convert_strl: Optional[Sequence[Label]] = None,
     ) -> None:
         """
         Export DataFrame object to Stata dta format.
diff --git a/pandas/io/stata.py b/pandas/io/stata.py
index 5e906e88abd9b8..abfaced478ae1a 100644
--- a/pandas/io/stata.py
+++ b/pandas/io/stata.py
@@ -11,12 +11,14 @@
 """
 from collections import abc
 import datetime
-from io import BytesIO
+from io import BytesIO, IOBase
+from pathlib import Path
 import os
 import struct
 import sys
 from typing import (
     Any,
+    AnyStr,
     BinaryIO,
     Dict,
     Hashable,
@@ -34,7 +36,7 @@
 from pandas._libs.lib import infer_dtype
 from pandas._libs.writers import max_len_string_array
-from pandas._typing import FilePathOrBuffer
+from pandas._typing import FilePathOrBuffer, Label
 from pandas.util._decorators import Appender
 
 from pandas.core.dtypes.common import (
@@ -58,8 +60,6 @@
 from pandas.io.common import get_filepath_or_buffer, stringify_path
 
-BytesOrString = TypeVar("BytesOrString", str, bytes)
-
 _version_error = (
     "Version of given Stata file is {version}. pandas supports importing "
     "versions 104, 105, 108, 111 (Stata 7SE), 113 (Stata 8/9), "
@@ -173,42 +173,6 @@
 """
 
 
-@Appender(_read_stata_doc)
-def read_stata(
-    filepath_or_buffer: FilePathOrBuffer,
-    convert_dates: bool = True,
-    convert_categoricals: bool = True,
-    index_col: Optional[str] = None,
-    convert_missing: bool = False,
-    preserve_dtypes: bool = True,
-    columns: Optional[Sequence[str]] = None,
-    order_categoricals: bool = True,
-    chunksize: Optional[int] = None,
-    iterator: bool = False,
-) -> DataFrame:
-
-    reader = StataReader(
-        filepath_or_buffer,
-        convert_dates=convert_dates,
-        convert_categoricals=convert_categoricals,
-        index_col=index_col,
-        convert_missing=convert_missing,
-        preserve_dtypes=preserve_dtypes,
-        columns=columns,
-        order_categoricals=order_categoricals,
-        chunksize=chunksize,
-    )
-
-    if iterator or chunksize:
-        data = reader
-    else:
-        try:
-            data = reader.read()
-        finally:
-            reader.close()
-    return data
-
-
 _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"]
@@ -363,7 +327,7 @@ def convert_delta_safe(
         conv_dates = convert_year_month_safe(year, month)
     elif fmt.startswith(("%tq", "tq")):  # Delta quarters relative to base
         year = stata_epoch.year + dates // 4
-        month = (dates % 4) * 3 + 1
+        month = (dates % 4) * 3 + 1  # type: Series
         conv_dates = convert_year_month_safe(year, month)
     elif fmt.startswith(("%th", "th")):  # Delta half-years relative to base
         year = stata_epoch.year + dates // 2
@@ -371,7 +335,7 @@ def convert_delta_safe(
         conv_dates = convert_year_month_safe(year, month)
     elif fmt.startswith(("%ty", "ty")):  # Years -- not delta
         year = dates
-        month = np.ones_like(dates)
+        month = np.ones_like(dates)  # type: np.ndarray
         conv_dates = convert_year_month_safe(year, month)
     else:
         raise ValueError(f"Date fmt {fmt} not understood")
@@ -407,8 +371,8 @@ def parse_dates_safe(
             delta = dates - stata_epoch
             d["delta"] = delta.values.astype(np.int64) // 1000  # microseconds
         if days or year:
-            dates = DatetimeIndex(dates)
-            d["year"], d["month"] = dates.year, dates.month
+            date_index = DatetimeIndex(dates)
+            d["year"], d["month"] = date_index.year, date_index.month
         if days:
             days = dates.astype(np.int64) - to_datetime(
                 d["year"], format="%Y"
@@ -419,7 +383,7 @@ def parse_dates_safe(
         if delta:
             delta = dates.values - stata_epoch
 
-            def f(x: datetime.timedelta) -> int:
+            def f(x: datetime.timedelta) -> float:
                 return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds
 
             v = np.vectorize(f)
@@ -430,10 +394,10 @@ def f(x: datetime.datetime) -> int:
             d["month"] = year_month.values - d["year"] * 100
         if days:
 
-            def f(x: datetime.datetime) -> int:
+            def g(x: datetime.datetime) -> int:
                 return (x - datetime.datetime(x.year, 1, 1)).days
 
-            v = np.vectorize(f)
+            v = np.vectorize(g)
             d["days"] = v(dates)
     else:
         raise ValueError(
@@ -630,7 +594,7 @@ class StataValueLabel:
         Encoding to use for value labels.
""" - def __init__(self, catarray: Series, encoding: str = "latin-1") -> None: + def __init__(self, catarray: Series, encoding: str = "latin-1"): if encoding not in ("latin-1", "utf-8"): raise ValueError("Only latin-1 and utf-8 are supported.") @@ -696,7 +660,7 @@ def generate_value_label(self, byteorder: str) -> bytes: bio.write(struct.pack(byteorder + "i", self.len)) # labname - labname = self.labname[:32].encode(encoding) + labname = str(self.labname)[:32].encode(encoding) lab_len = 32 if encoding not in ("utf-8", "utf8") else 128 labname = _pad_bytes(labname, lab_len + 1) bio.write(labname) @@ -766,7 +730,7 @@ class StataMissingValue: """ # Construct a dictionary of missing values - MISSING_VALUES = {} + MISSING_VALUES: Dict[float, str] = {} bases = (101, 32741, 2147483621) for b in bases: # Conversion to long to avoid hash issues on 32 bit platforms #8968 @@ -777,21 +741,21 @@ class StataMissingValue: float32_base = b"\x00\x00\x00\x7f" increment = struct.unpack(" 0: - MISSING_VALUES[value] += chr(96 + i) - int_value = struct.unpack(" 0: - MISSING_VALUES[value] += chr(96 + i) - int_value = struct.unpack("q", struct.pack(" None: + def __init__(self, value: Union[int, float]): self._value = value # Conversion to int to avoid hash issues on 32 bit platforms #8968 value = int(value) if value < 2147483648 else float(value) @@ -863,7 +827,7 @@ def get_base_missing_value(cls, dtype: np.dtype) -> Union[int, float]: class StataParser: - def __init__(self) -> None: + def __init__(self): # type code. # -------------------- @@ -1040,9 +1004,9 @@ def __init__( columns: Optional[Sequence[str]] = None, order_categoricals: bool = True, chunksize: Optional[int] = None, - ) -> None: + ): super().__init__() - self.col_sizes = () + self.col_sizes: List[int] = [] # Arguments to the reader (can be temporarily overridden in # calls to read). @@ -1053,7 +1017,7 @@ def __init__( self._preserve_dtypes = preserve_dtypes self._columns = columns self._order_categoricals = order_categoricals - self._encoding = None + self._encoding = "" self._chunksize = chunksize # State variables for the file @@ -1063,7 +1027,9 @@ def __init__( self._column_selector_set = False self._value_labels_read = False self._data_read = False - self._dtype = None + # TODO: Place holder for mypy, is this right? 
+ self._dtype = np.dtype("int") + self._dtype_setup = False self._lines_read = 0 self._native_byteorder = _set_endianness(sys.byteorder) @@ -1073,7 +1039,7 @@ def __init__( if isinstance(path_or_buf, (str, bytes)): self.path_or_buf = open(path_or_buf, "rb") - else: + elif isinstance(path_or_buf, IOBase): # Copy to BytesIO, and ensure no encoding contents = path_or_buf.read() self.path_or_buf = BytesIO(contents) @@ -1214,7 +1180,7 @@ def f(typ: int) -> Union[int, str]: typlist = [f(x) for x in raw_typlist] - def f(typ: int) -> Union[str, np.dtype]: + def g(typ: int) -> Union[str, np.dtype]: if typ <= 2045: return str(typ) try: @@ -1222,7 +1188,7 @@ def f(typ: int) -> Union[str, np.dtype]: except KeyError: raise ValueError(f"cannot convert stata dtype [{typ}]") - dtyplist = [f(x) for x in raw_typlist] + dtyplist = [g(x) for x in raw_typlist] return typlist, dtyplist @@ -1399,22 +1365,24 @@ def _read_old_header(self, first_char: bytes) -> None: def _setup_dtype(self) -> np.dtype: """Map between numpy and state dtypes""" - if self._dtype is not None: + if self._dtype_setup: return self._dtype - dtype = [] # Convert struct data types to numpy data type + dtypes = [] # Convert struct data types to numpy data type for i, typ in enumerate(self.typlist): if typ in self.NUMPY_TYPE_MAP: - dtype.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) + dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) else: - dtype.append(("s" + str(i), "S" + str(typ))) - dtype = np.dtype(dtype) - self._dtype = dtype + dtypes.append(("s" + str(i), "S" + str(typ))) + self._dtype = np.dtype(dtypes) + self._dtype_setup = True return self._dtype def _calcsize(self, fmt: Union[int, str]) -> int: - return type(fmt) is int and fmt or struct.calcsize(self.byteorder + fmt) + if isinstance(fmt, int): + return fmt + return struct.calcsize(self.byteorder + fmt) def _decode(self, s: bytes) -> str: # have bytes not strings, so must decode @@ -1440,7 +1408,7 @@ def _read_value_labels(self) -> None: if self.format_version <= 108: # Value labels are not supported in version 108 and earlier. 
             self._value_labels_read = True
-            self.value_label_dict = dict()
+            self.value_label_dict: Dict[str, Dict[Union[float, int], str]] = {}
             return
 
         if self.format_version >= 117:
@@ -1450,7 +1418,7 @@ def _read_value_labels(self) -> None:
             self.path_or_buf.seek(self.data_location + offset)
 
         self._value_labels_read = True
-        self.value_label_dict = dict()
+        self.value_label_dict = {}
 
         while True:
             if self.format_version >= 117:
@@ -1510,9 +1478,12 @@ def _read_strls(self) -> None:
             length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0]
             va = self.path_or_buf.read(length)
             if typ == 130:
-                va = va[0:-1].decode(self._encoding)
-            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
-            self.GSO[str(v_o)] = va
+                decoded_va = va[0:-1].decode(self._encoding)
+            else:
+                # Stata says typ 129 can be binary, so use str
+                decoded_va = str(va)
+            # Wrap v_o in a string to allow uint64 values as keys on 32bit OS
+            self.GSO[str(v_o)] = decoded_va
 
     def __next__(self) -> DataFrame:
         return self.read(nrows=self._chunksize or 1)
@@ -1731,9 +1702,9 @@ def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFra
             replacements[colname] = replacement
         if replacements:
             columns = data.columns
-            replacements = DataFrame(replacements)
-            data = concat([data.drop(replacements.columns, 1), replacements], 1)
-            data = data[columns]
+            replacement_df = DataFrame(replacements)
+            replaced = concat([data.drop(replacement_df.columns, 1), replacement_df], 1)
+            data = replaced[columns]
         return data
 
     def _insert_strls(self, data: DataFrame) -> DataFrame:
@@ -1805,8 +1776,8 @@ def _do_convert_categoricals(
                     cat_data.categories = categories
                 except ValueError:
                     vc = Series(categories).value_counts()
-                    repeats = list(vc.index[vc > 1])
-                    repeats = "-" * 80 + "\n" + "\n".join(repeats)
+                    repeated_cats = list(vc.index[vc > 1])
+                    repeats = "-" * 80 + "\n" + "\n".join(repeated_cats)
                     # GH 25772
                     msg = f"""
 Value labels for column {col} are not unique. These cannot be converted to
@@ -1821,8 +1792,8 @@ def _do_convert_categoricals(
                     """
                     raise ValueError(msg)
                 # TODO: is the next line needed above in the data(...) method?
-                cat_data = Series(cat_data, index=data.index)
-                cat_converted_data.append((col, cat_data))
+                cat_series = Series(cat_data, index=data.index)
+                cat_converted_data.append((col, cat_series))
             else:
                 cat_converted_data.append((col, data[col]))
         data = DataFrame.from_dict(dict(cat_converted_data))
@@ -1861,6 +1832,41 @@ def value_labels(self) -> Dict[str, Dict[Union[float, int], str]]:
         return self.value_label_dict
 
 
+@Appender(_read_stata_doc)
+def read_stata(
+    filepath_or_buffer: FilePathOrBuffer,
+    convert_dates: bool = True,
+    convert_categoricals: bool = True,
+    index_col: Optional[str] = None,
+    convert_missing: bool = False,
+    preserve_dtypes: bool = True,
+    columns: Optional[Sequence[str]] = None,
+    order_categoricals: bool = True,
+    chunksize: Optional[int] = None,
+    iterator: bool = False,
+) -> Union[DataFrame, StataReader]:
+
+    reader = StataReader(
+        filepath_or_buffer,
+        convert_dates=convert_dates,
+        convert_categoricals=convert_categoricals,
+        index_col=index_col,
+        convert_missing=convert_missing,
+        preserve_dtypes=preserve_dtypes,
+        columns=columns,
+        order_categoricals=order_categoricals,
+        chunksize=chunksize,
+    )
+
+    if iterator or chunksize:
+        return reader
+
+    try:
+        data = reader.read()
+    finally:
+        reader.close()
+    return data
+
+
 def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
     """
     Open a binary file or no-op if file-like.
@@ -1876,10 +1882,12 @@ def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]:
     own : bool
         True if the file was created, otherwise False
     """
-    if hasattr(fname, "write"):
-        # if 'b' not in fname.mode:
+    if isinstance(fname, IOBase):
         return fname, False
-    return open(fname, "wb"), True
+    elif isinstance(fname, (str, Path)):
+        return open(fname, "wb"), True
+    else:
+        raise ValueError("fname must be a binary file, buffer or path-like.")
 
 
 def _set_endianness(endianness: str) -> str:
@@ -1891,7 +1899,7 @@ def _set_endianness(endianness: str) -> str:
         raise ValueError(f"Endianness {endianness} not understood")
 
 
-def _pad_bytes(name: BytesOrString, length: int) -> BytesOrString:
+def _pad_bytes(name: AnyStr, length: int) -> AnyStr:
     """
     Take a char string and pads it with null bytes until it's length chars.
     """
@@ -1925,8 +1933,7 @@ def _convert_datetime_to_stata_type(fmt: str) -> np.dtype:
         raise NotImplementedError(f"Format {fmt} not implemented")
 
 
-# TODO
-def _maybe_convert_to_int_keys(convert_dates: Dict, varlist) -> Dict:
+def _maybe_convert_to_int_keys(convert_dates: Dict, varlist: List[Label]) -> Dict:
     new_dict = {}
     for key in convert_dates:
         if not convert_dates[key].startswith("%"):  # make sure proper fmts
@@ -2093,15 +2100,14 @@ def __init__(
         self,
         fname: FilePathOrBuffer,
         data: DataFrame,
-        convert_dates: Optional[Dict[Hashable, str]] = None,
+        convert_dates: Optional[Dict[Label, str]] = None,
         write_index: bool = True,
         byteorder: Optional[str] = None,
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
-        variable_labels: Optional[Dict[Hashable, str]] = None,
-    ) -> None:
+        variable_labels: Optional[Dict[Label, str]] = None,
+    ):
         super().__init__()
-        self._file = None
         self._convert_dates = {} if convert_dates is None else convert_dates
         self._write_index = write_index
         self._time_stamp = time_stamp
@@ -2116,7 +2122,7 @@ def __init__(
         self._byteorder = _set_endianness(byteorder)
         self._fname = stringify_path(fname)
         self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8}
-        self._converted_names = {}
+        self._converted_names: Dict[Label, str] = {}
 
     def _write(self, to_write: str) -> None:
         """
@@ -2302,7 +2308,9 @@ def _prepare_pandas(self, data: DataFrame) -> None:
         data = data.copy()
 
         if self._write_index:
-            data = data.reset_index()
+            temp = data.reset_index()
+            if isinstance(temp, DataFrame):
+                data = temp
 
         # Ensure column names are strings
         data = self._check_column_names(data)
@@ -2345,7 +2353,8 @@ def _prepare_pandas(self, data: DataFrame) -> None:
         # set the given format for the datetime cols
         if self._convert_dates is not None:
             for key in self._convert_dates:
-                self.fmtlist[key] = self._convert_dates[key]
+                if isinstance(key, int):
+                    self.fmtlist[key] = self._convert_dates[key]
 
     def _encode_strings(self) -> None:
         """
@@ -2407,7 +2416,8 @@ def write_file(self) -> None:
             self._close()
             if self._own_file:
                 try:
-                    os.unlink(self._fname)
+                    if isinstance(self._fname, (str, Path)):
+                        os.unlink(self._fname)
                 except OSError:
                     warnings.warn(
                         f"This save was not successful but {self._fname} could not "
@@ -2479,9 +2489,9 @@ def _write_header(
         self._file.write(struct.pack(byteorder + "i", self.nobs)[:4])
         # data label 81 bytes, char, null terminated
         if data_label is None:
-            self._file.write(self._null_terminate(_pad_bytes("", 80)))
+            self._file.write(self._null_terminate_bytes(_pad_bytes("", 80)))
         else:
-            self._file.write(self._null_terminate(_pad_bytes(data_label[:80], 80)))
+            self._file.write(self._null_terminate_bytes(_pad_bytes(data_label[:80], 80)))
         # time stamp, 18 bytes, char, null terminated
         # format dd Mon yyyy hh:mm
         if time_stamp is None:
@@ -2510,7 +2520,7 @@ def _write_header(
             + month_lookup[time_stamp.month]
             + time_stamp.strftime(" %Y %H:%M")
         )
-        self._file.write(self._null_terminate(ts))
+        self._file.write(self._null_terminate_bytes(ts))
 
     def _write_variable_types(self) -> None:
         for typ in self.typlist:
@@ -2520,7 +2530,7 @@ def _write_varnames(self) -> None:
         # varlist names are checked by _check_column_names
        # varlist, requires null terminated
         for name in self.varlist:
-            name = self._null_terminate(name, True)
+            name = self._null_terminate_str(name)
             name = _pad_bytes(name[:32], 33)
             self._write(name)
 
@@ -2540,7 +2550,7 @@ def _write_value_label_names(self) -> None:
             # Use variable name when categorical
             if self._is_col_cat[i]:
                 name = self.varlist[i]
-                name = self._null_terminate(name, True)
+                name = self._null_terminate_str(name)
                 name = _pad_bytes(name[:32], 33)
                 self._write(name)
             else:  # Default is empty label
@@ -2609,15 +2619,14 @@ def _prepare_data(self) -> np.recarray:
     def _write_data(self, records: np.recarray) -> None:
         self._file.write(records.tobytes())
 
-    def _null_terminate(self, s: str, as_string: bool = False) -> Union[str, bytes]:
-        null_byte = "\x00"
-        s += null_byte
-
-        if not as_string:
-            s = s.encode(self._encoding)
-
+    @staticmethod
+    def _null_terminate_str(s: str) -> str:
+        s += "\x00"
         return s
 
+    def _null_terminate_bytes(self, s: str) -> bytes:
+        return self._null_terminate_str(s).encode(self._encoding)
+
 
 def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int:
     """
@@ -2705,7 +2714,7 @@ def __init__(
         columns: Sequence[str],
         version: int = 117,
         byteorder: Optional[str] = None,
-    ) -> None:
+    ):
         if version not in (117, 118, 119):
             raise ValueError("Only dta versions 117, 118 and 119 supported")
         self._dta_ver = version
@@ -2933,16 +2942,18 @@ def __init__(
         self,
         fname: FilePathOrBuffer,
         data: DataFrame,
-        convert_dates: Optional[Dict[Hashable, str]] = None,
+        convert_dates: Optional[Dict[Label, str]] = None,
         write_index: bool = True,
         byteorder: Optional[str] = None,
         time_stamp: Optional[datetime.datetime] = None,
        data_label: Optional[str] = None,
-        variable_labels: Optional[Dict[Hashable, str]] = None,
-        convert_strl: Optional[Sequence[Hashable]] = None,
-    ) -> None:
-        # Shallow copy since convert_strl might be modified later
-        self._convert_strl = [] if convert_strl is None else convert_strl[:]
+        variable_labels: Optional[Dict[Label, str]] = None,
+        convert_strl: Optional[Sequence[Label]] = None,
+    ):
+        # Copy to new list since convert_strl might be modified later
+        self._convert_strl: List[Label] = []
+        if convert_strl is not None:
+            self._convert_strl.extend(convert_strl)
 
         super().__init__(
             fname,
@@ -2954,8 +2965,8 @@ def __init__(
             data_label=data_label,
             variable_labels=variable_labels,
         )
-        self._map = None
-        self._strl_blob = None
+        self._map: Dict[str, int] = {}
+        self._strl_blob = b""
 
     @staticmethod
    def _tag(val: Union[str, bytes], tag: str) -> bytes:
@@ -2989,11 +3000,11 @@ def _write_header(
         bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N"))
         # data label 81 bytes, char, null terminated
         label = data_label[:80] if data_label is not None else ""
-        label = label.encode(self._encoding)
+        encoded_label = label.encode(self._encoding)
         label_size = "B" if self._dta_version == 117 else "H"
-        label_len = struct.pack(byteorder + label_size, len(label))
-        label = label_len + label
-        bio.write(self._tag(label, "label"))
+        label_len = struct.pack(byteorder + label_size, len(encoded_label))
+        encoded_label = label_len + encoded_label
+        bio.write(self._tag(encoded_label, "label"))
         # time stamp, 18 bytes, char, null terminated
         # format dd Mon yyyy hh:mm
         if time_stamp is None:
@@ -3022,8 +3033,8 @@ def _write_header(
             + time_stamp.strftime(" %Y %H:%M")
         )
         # '\x11' added due to inspection of Stata file
-        ts = b"\x11" + bytes(ts, "utf-8")
-        bio.write(self._tag(ts, "timestamp"))
+        stata_ts = b"\x11" + bytes(ts, "utf-8")
+        bio.write(self._tag(stata_ts, "timestamp"))
         bio.seek(0)
         self._file.write(self._tag(bio.read(), "header"))
 
@@ -3031,7 +3042,7 @@ def _write_map(self) -> None:
         """Called twice during file write. The first populates the values in
         the map with 0s. The second call writes the final map locations when
         all blocks have been written."""
-        if self._map is None:
+        if not self._map:
             self._map = dict(
                 (
                     ("stata_data", 0),
@@ -3072,7 +3083,7 @@ def _write_varnames(self) -> None:
         # 118 scales by 4 to accommodate utf-8 data worst case encoding
         vn_len = 32 if self._dta_version == 117 else 128
         for name in self.varlist:
-            name = self._null_terminate(name, True)
+            name = self._null_terminate_str(name)
             name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1)
             bio.write(name)
         bio.seek(0)
@@ -3102,9 +3113,9 @@ def _write_value_label_names(self) -> None:
             name = ""  # default name
             if self._is_col_cat[i]:
                 name = self.varlist[i]
-            name = self._null_terminate(name, True)
-            name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
-            bio.write(name)
+            name = self._null_terminate_str(name)
+            encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1)
+            bio.write(encoded_name)
         bio.seek(0)
         self._file.write(self._tag(bio.read(), "value_label_names"))
 
@@ -3154,10 +3165,7 @@ def _write_data(self, records) -> None:
 
     def _write_strls(self) -> None:
         self._update_map("strls")
-        strls = b""
-        if self._strl_blob is not None:
-            strls = self._strl_blob
-        self._file.write(self._tag(strls, "strls"))
+        self._file.write(self._tag(self._strl_blob, "strls"))
 
     def _write_expansion_fields(self) -> None:
         """No-op in dta 117+"""
@@ -3310,15 +3318,15 @@ def __init__(
         self,
         fname: FilePathOrBuffer,
         data: DataFrame,
-        convert_dates: Optional[Dict[Hashable, str]] = None,
+        convert_dates: Optional[Dict[Label, str]] = None,
         write_index: bool = True,
         byteorder: Optional[str] = None,
         time_stamp: Optional[datetime.datetime] = None,
         data_label: Optional[str] = None,
-        variable_labels: Optional[Dict[Hashable, str]] = None,
-        convert_strl: Optional[Sequence[Hashable]] = None,
+        variable_labels: Optional[Dict[Label, str]] = None,
+        convert_strl: Optional[Sequence[Label]] = None,
         version: Optional[int] = None,
-    ) -> None:
+    ):
         if version is None:
             version = 118 if data.shape[1] <= 32767 else 119
         elif version not in (118, 119):
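
For reference, a minimal round trip through the public surface this patch annotates (DataFrame.to_stata and read_stata). This is a sketch only; the file name, column names, and labels below are illustrative and not taken from the PR:

    import pandas as pd

    df = pd.DataFrame(
        {
            "city": ["Oslo", "Lima", "Pune"],
            "visited": pd.date_range("2000-01-31", periods=3),
        }
    )

    # convert_dates and variable_labels are keyed by column label
    # (the Label alias introduced above); "%td" stores day-level dates.
    df.to_stata(
        "demo.dta",  # illustrative path
        convert_dates={"visited": "%td"},
        variable_labels={"city": "City of interview"},
        data_label="annotation demo",
        write_index=False,
    )

    # read_stata returns a DataFrame by default, and the StataReader itself
    # when iterator=True or chunksize is set -- hence the new
    # Union[DataFrame, StataReader] return annotation.
    frame = pd.read_stata("demo.dta")

    reader = pd.read_stata("demo.dta", iterator=True)
    first_two = reader.read(2)
    reader.close()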