diff --git a/pandas/core/frame.py b/pandas/core/frame.py index 012fb1d0c2eb7f..dbcdcd17255b56 100644 --- a/pandas/core/frame.py +++ b/pandas/core/frame.py @@ -11,6 +11,7 @@ import collections from collections import abc +import datetime from io import StringIO import itertools import sys @@ -19,6 +20,7 @@ IO, TYPE_CHECKING, Any, + Dict, FrozenSet, Hashable, Iterable, @@ -39,7 +41,7 @@ from pandas._config import get_option from pandas._libs import algos as libalgos, lib, properties -from pandas._typing import Axes, Axis, Dtype, FilePathOrBuffer, Level, Renamer +from pandas._typing import Axes, Axis, Dtype, FilePathOrBuffer, Label, Level, Renamer from pandas.compat import PY37 from pandas.compat._optional import import_optional_dependency from pandas.compat.numpy import function as nv @@ -1851,16 +1853,16 @@ def _from_arrays(cls, arrays, columns, index, dtype=None) -> "DataFrame": @deprecate_kwarg(old_arg_name="fname", new_arg_name="path") def to_stata( self, - path, - convert_dates=None, - write_index=True, - byteorder=None, - time_stamp=None, - data_label=None, - variable_labels=None, - version=114, - convert_strl=None, - ): + path: FilePathOrBuffer, + convert_dates: Optional[Dict[Label, str]] = None, + write_index: bool = True, + byteorder: Optional[str] = None, + time_stamp: Optional[datetime.datetime] = None, + data_label: Optional[str] = None, + variable_labels: Optional[Dict[Label, str]] = None, + version: Optional[int] = 114, + convert_strl: Optional[Sequence[Label]] = None, + ) -> None: """ Export DataFrame object to Stata dta format. @@ -1954,11 +1956,13 @@ def to_stata( raise ValueError("strl is not supported in format 114") from pandas.io.stata import StataWriter as statawriter elif version == 117: - from pandas.io.stata import StataWriter117 as statawriter + # mypy: Name 'statawriter' already defined (possibly by an import) + from pandas.io.stata import StataWriter117 as statawriter # type: ignore else: # versions 118 and 119 - from pandas.io.stata import StataWriterUTF8 as statawriter + # mypy: Name 'statawriter' already defined (possibly by an import) + from pandas.io.stata import StataWriterUTF8 as statawriter # type:ignore - kwargs = {} + kwargs: Dict[str, Any] = {} if version is None or version >= 117: # strl conversion is only supported >= 117 kwargs["convert_strl"] = convert_strl @@ -1966,7 +1970,8 @@ def to_stata( # Specifying the version is only supported for UTF8 (118 or 119) kwargs["version"] = version - writer = statawriter( + # mypy: Too many arguments for "StataWriter" + writer = statawriter( # type: ignore path, self, convert_dates=convert_dates, diff --git a/pandas/io/common.py b/pandas/io/common.py index cf19169214c351..00f2961e416179 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -160,10 +160,9 @@ def get_filepath_or_buffer( Returns ------- - tuple of ({a filepath_ or buffer or S3File instance}, - encoding, str, - compression, str, - should_close, bool) + Tuple[FilePathOrBuffer, str, str, bool] + Tuple containing the filepath or buffer, the encoding, the compression + and should_close. 
""" filepath_or_buffer = stringify_path(filepath_or_buffer) diff --git a/pandas/io/stata.py b/pandas/io/stata.py index ec200a1ad8409e..dd22a20860703c 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -11,11 +11,12 @@ """ from collections import abc import datetime -from io import BytesIO +from io import BytesIO, IOBase import os +from pathlib import Path import struct import sys -from typing import Any, Dict, Hashable, Optional, Sequence +from typing import Any, AnyStr, BinaryIO, Dict, List, Optional, Sequence, Tuple, Union import warnings from dateutil.relativedelta import relativedelta @@ -23,7 +24,7 @@ from pandas._libs.lib import infer_dtype from pandas._libs.writers import max_len_string_array -from pandas._typing import FilePathOrBuffer +from pandas._typing import FilePathOrBuffer, Label from pandas.util._decorators import Appender from pandas.core.dtypes.common import ( @@ -160,49 +161,16 @@ """ -@Appender(_read_stata_doc) -def read_stata( - filepath_or_buffer, - convert_dates=True, - convert_categoricals=True, - index_col=None, - convert_missing=False, - preserve_dtypes=True, - columns=None, - order_categoricals=True, - chunksize=None, - iterator=False, -): - - reader = StataReader( - filepath_or_buffer, - convert_dates=convert_dates, - convert_categoricals=convert_categoricals, - index_col=index_col, - convert_missing=convert_missing, - preserve_dtypes=preserve_dtypes, - columns=columns, - order_categoricals=order_categoricals, - chunksize=chunksize, - ) - - if iterator or chunksize: - data = reader - else: - try: - data = reader.read() - finally: - reader.close() - return data - - _date_formats = ["%tc", "%tC", "%td", "%d", "%tw", "%tm", "%tq", "%th", "%ty"] stata_epoch = datetime.datetime(1960, 1, 1) -def _stata_elapsed_date_to_datetime_vec(dates, fmt): +# TODO: Add typing. As of January 2020 it is not possible to type this function since +# mypy doesn't understand that a Series and an int can be combined using mathematical +# operations. (+, -). +def _stata_elapsed_date_to_datetime_vec(dates, fmt) -> Series: """ Convert from SIF to datetime. https://www.stata.com/help.cgi?datetime @@ -247,9 +215,6 @@ def _stata_elapsed_date_to_datetime_vec(dates, fmt): half-years since 1960h1 yearly date - ty years since 0000 - - If you don't have pandas with datetime support, then you can't do - milliseconds accurately. """ MIN_YEAR, MAX_YEAR = Timestamp.min.year, Timestamp.max.year MAX_DAY_DELTA = (Timestamp.max - datetime.datetime(1960, 1, 1)).days @@ -257,7 +222,7 @@ def _stata_elapsed_date_to_datetime_vec(dates, fmt): MIN_MS_DELTA = MIN_DAY_DELTA * 24 * 3600 * 1000 MAX_MS_DELTA = MAX_DAY_DELTA * 24 * 3600 * 1000 - def convert_year_month_safe(year, month): + def convert_year_month_safe(year, month) -> Series: """ Convert year and month to datetimes, using pandas vectorized versions when the date range falls within the range supported by pandas. @@ -272,7 +237,7 @@ def convert_year_month_safe(year, month): [datetime.datetime(y, m, 1) for y, m in zip(year, month)], index=index ) - def convert_year_days_safe(year, days): + def convert_year_days_safe(year, days) -> Series: """ Converts year (e.g. 
1999) and days since the start of the year to a datetime or datetime64 Series @@ -287,7 +252,7 @@ def convert_year_days_safe(year, days): ] return Series(value, index=index) - def convert_delta_safe(base, deltas, unit): + def convert_delta_safe(base, deltas, unit) -> Series: """ Convert base dates and deltas to datetimes, using pandas vectorized versions if the deltas satisfy restrictions required to be expressed @@ -348,16 +313,16 @@ def convert_delta_safe(base, deltas, unit): conv_dates = convert_year_month_safe(year, month) elif fmt.startswith(("%tq", "tq")): # Delta quarters relative to base year = stata_epoch.year + dates // 4 - month = (dates % 4) * 3 + 1 - conv_dates = convert_year_month_safe(year, month) + quarter_month = (dates % 4) * 3 + 1 + conv_dates = convert_year_month_safe(year, quarter_month) elif fmt.startswith(("%th", "th")): # Delta half-years relative to base year = stata_epoch.year + dates // 2 month = (dates % 2) * 6 + 1 conv_dates = convert_year_month_safe(year, month) elif fmt.startswith(("%ty", "ty")): # Years -- not delta year = dates - month = np.ones_like(dates) - conv_dates = convert_year_month_safe(year, month) + first_month = np.ones_like(dates) + conv_dates = convert_year_month_safe(year, first_month) else: raise ValueError(f"Date fmt {fmt} not understood") @@ -367,7 +332,7 @@ def convert_delta_safe(base, deltas, unit): return conv_dates -def _datetime_to_stata_elapsed_vec(dates, fmt): +def _datetime_to_stata_elapsed_vec(dates: Series, fmt: str) -> Series: """ Convert from datetime to SIF. https://www.stata.com/help.cgi?datetime @@ -387,21 +352,26 @@ def parse_dates_safe(dates, delta=False, year=False, days=False): d = {} if is_datetime64_dtype(dates.values): if delta: - delta = dates - stata_epoch - d["delta"] = delta.values.astype(np.int64) // 1000 # microseconds + time_delta = dates - stata_epoch + d["delta"] = time_delta.values.astype(np.int64) // 1000 # microseconds if days or year: - dates = DatetimeIndex(dates) - d["year"], d["month"] = dates.year, dates.month + # ignore since mypy reports that DatetimeIndex has no year/month + date_index = DatetimeIndex(dates) + d["year"] = date_index.year # type: ignore + d["month"] = date_index.month # type: ignore if days: - days = dates.astype(np.int64) - to_datetime( + days_in_ns = dates.astype(np.int64) - to_datetime( d["year"], format="%Y" ).astype(np.int64) - d["days"] = days // NS_PER_DAY + d["days"] = days_in_ns // NS_PER_DAY elif infer_dtype(dates, skipna=False) == "datetime": if delta: delta = dates.values - stata_epoch - f = lambda x: US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds + + def f(x: datetime.timedelta) -> float: + return US_PER_DAY * x.days + 1000000 * x.seconds + x.microseconds + v = np.vectorize(f) d["delta"] = v(delta) if year: @@ -409,8 +379,11 @@ def parse_dates_safe(dates, delta=False, year=False, days=False): d["year"] = year_month.values // 100 d["month"] = year_month.values - d["year"] * 100 if days: - f = lambda x: (x - datetime.datetime(x.year, 1, 1)).days - v = np.vectorize(f) + + def g(x: datetime.datetime) -> int: + return (x - datetime.datetime(x.year, 1, 1)).days + + v = np.vectorize(g) d["days"] = v(dates) else: raise ValueError( @@ -507,7 +480,7 @@ class InvalidColumnName(Warning): """ -def _cast_to_stata_types(data): +def _cast_to_stata_types(data: DataFrame) -> DataFrame: """Checks the dtypes of the columns of a pandas DataFrame for compatibility with the data types and ranges supported by Stata, and converts if necessary. 
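[Editor's note] The SIF conversions above all reduce to integer arithmetic against the 1960-01-01 Stata epoch. A minimal scalar sketch of that arithmetic (the helper names here are illustrative, not part of the patch):

import datetime

stata_epoch = datetime.datetime(1960, 1, 1)

def td_to_datetime(days: int) -> datetime.datetime:
    # %td stores elapsed days since the Stata epoch
    return stata_epoch + datetime.timedelta(days=days)

def tq_to_year_month(quarters: int) -> tuple:
    # %tq stores elapsed quarters; same year/month math as the vectorized helper
    year = stata_epoch.year + quarters // 4
    month = (quarters % 4) * 3 + 1
    return year, month

assert td_to_datetime(0) == datetime.datetime(1960, 1, 1)
assert tq_to_year_month(5) == (1961, 4)  # quarter 5 after 1960q1 begins April 1961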
@@ -601,13 +574,13 @@ class StataValueLabel: Parameters ---------- - catarray : Categorical + catarray : Series Categorical Series to encode encoding : {"latin-1", "utf-8"} Encoding to use for value labels. """ - def __init__(self, catarray, encoding="latin-1"): + def __init__(self, catarray: Series, encoding: str = "latin-1"): if encoding not in ("latin-1", "utf-8"): raise ValueError("Only latin-1 and utf-8 are supported.") @@ -616,10 +589,10 @@ def __init__(self, catarray, encoding="latin-1"): categories = catarray.cat.categories self.value_labels = list(zip(np.arange(len(categories)), categories)) self.value_labels.sort(key=lambda x: x[0]) - self.text_len = np.int32(0) - self.off = [] - self.val = [] - self.txt = [] + self.text_len = 0 + self.off: List[int] = [] + self.val: List[int] = [] + self.txt: List[bytes] = [] self.n = 0 # Compute lengths and setup lists of offsets and labels @@ -651,13 +624,7 @@ def __init__(self, catarray, encoding="latin-1"): # Total length self.len = 4 + 4 + 4 * self.n + 4 * self.n + self.text_len - def _encode(self, s): - """ - Python 3 compatibility shim - """ - return s.encode(self._encoding) - - def generate_value_label(self, byteorder): + def generate_value_label(self, byteorder: str) -> bytes: """ Generate the binary representation of the value labels. @@ -679,7 +646,7 @@ def generate_value_label(self, byteorder): bio.write(struct.pack(byteorder + "i", self.len)) # labname - labname = self.labname[:32].encode(encoding) + labname = str(self.labname)[:32].encode(encoding) lab_len = 32 if encoding not in ("utf-8", "utf8") else 128 labname = _pad_bytes(labname, lab_len + 1) bio.write(labname) @@ -717,16 +684,9 @@ class StataMissingValue: Parameters ---------- - value : int8, int16, int32, float32 or float64 + value : {int, float} The Stata missing value code - Attributes - ---------- - string : string - String representation of the Stata missing value - value : int8, int16, int32, float32 or float64 - The original encoded missing value - Notes ----- More information: @@ -756,7 +716,7 @@ class StataMissingValue: """ # Construct a dictionary of missing values - MISSING_VALUES = {} + MISSING_VALUES: Dict[float, str] = {} bases = (101, 32741, 2147483621) for b in bases: # Conversion to long to avoid hash issues on 32 bit platforms #8968 @@ -767,21 +727,21 @@ class StataMissingValue: float32_base = b"\x00\x00\x00\x7f" increment = struct.unpack("<i", b"\x00\x08\x00\x00")[0] for i in range(27): - value = struct.unpack("<f", float32_base)[0] - MISSING_VALUES[value] = "." + key = struct.unpack("<f", float32_base)[0] + MISSING_VALUES[key] = "." if i > 0: - MISSING_VALUES[value] += chr(96 + i) - int_value = struct.unpack("<i", struct.pack("<f", value))[0] + increment + MISSING_VALUES[key] += chr(96 + i) + int_value = struct.unpack("<i", struct.pack("<f", key))[0] + increment float32_base = struct.pack("<i", int_value) float64_base = b"\x00\x00\x00\x00\x00\x00\xe0\x7f" increment = struct.unpack("q", b"\x00\x00\x00\x00\x00\x01\x00\x00")[0] for i in range(27): - value = struct.unpack("<d", float64_base)[0] - MISSING_VALUES[value] = "." + key = struct.unpack("<d", float64_base)[0] + MISSING_VALUES[key] = "." if i > 0: - MISSING_VALUES[value] += chr(96 + i) - int_value = struct.unpack("q", struct.pack("<d", value))[0] + increment + MISSING_VALUES[key] += chr(96 + i) + int_value = struct.unpack("q", struct.pack("<d", key))[0] + increment float64_base = struct.pack("q", int_value) BASE_MISSING_VALUES = { @@ -792,19 +752,35 @@ class StataMissingValue: "float64": struct.unpack("<d", float64_base)[0], } - def __init__(self, value): + def __init__(self, value: Union[int, float]): self._value = value # Conversion to int to avoid hash issues on 32 bit platforms #8968 value = int(value) if value < 2147483648 else float(value) self._str = self.MISSING_VALUES[value] - string = property( - lambda self: self._str, - doc="The Stata representation of the missing value: '.', '.a'..'.z'", - ) - value = property( - lambda self: self._value, doc="The binary representation of the missing value." - ) + @property + def string(self) -> str: + """ + The Stata representation of the missing value: '.', '.a'..'.z' + + Returns + ------- + str + The representation of the missing value. + """ + return self._str + + @property + def value(self) -> Union[int, float]: + """ + The binary representation of the missing value. + + Returns + ------- + {int, float} + The binary representation of the missing value.
+ """ + return self._value def __str__(self) -> str: return self.string @@ -820,7 +796,7 @@ def __eq__(self, other: Any) -> bool: ) @classmethod - def get_base_missing_value(cls, dtype): + def get_base_missing_value(cls, dtype: np.dtype) -> Union[int, float]: if dtype == np.int8: value = cls.BASE_MISSING_VALUES["int8"] elif dtype == np.int16: @@ -1005,18 +981,18 @@ class StataReader(StataParser, abc.Iterator): def __init__( self, - path_or_buf, - convert_dates=True, - convert_categoricals=True, - index_col=None, - convert_missing=False, - preserve_dtypes=True, - columns=None, - order_categoricals=True, - chunksize=None, + path_or_buf: FilePathOrBuffer, + convert_dates: bool = True, + convert_categoricals: bool = True, + index_col: Optional[str] = None, + convert_missing: bool = False, + preserve_dtypes: bool = True, + columns: Optional[Sequence[str]] = None, + order_categoricals: bool = True, + chunksize: Optional[int] = None, ): super().__init__() - self.col_sizes = () + self.col_sizes: List[int] = [] # Arguments to the reader (can be temporarily overridden in # calls to read). @@ -1027,7 +1003,7 @@ def __init__( self._preserve_dtypes = preserve_dtypes self._columns = columns self._order_categoricals = order_categoricals - self._encoding = None + self._encoding = "" self._chunksize = chunksize # State variables for the file @@ -1047,7 +1023,7 @@ def __init__( if isinstance(path_or_buf, (str, bytes)): self.path_or_buf = open(path_or_buf, "rb") - else: + elif isinstance(path_or_buf, IOBase): # Copy to BytesIO, and ensure no encoding contents = path_or_buf.read() self.path_or_buf = BytesIO(contents) @@ -1055,22 +1031,22 @@ def __init__( self._read_header() self._setup_dtype() - def __enter__(self): + def __enter__(self) -> "StataReader": """ enter context manager """ return self - def __exit__(self, exc_type, exc_value, traceback): + def __exit__(self, exc_type, exc_value, traceback) -> None: """ exit context manager """ self.close() - def close(self): + def close(self) -> None: """ close the handle if its open """ try: self.path_or_buf.close() except IOError: pass - def _set_encoding(self): + def _set_encoding(self) -> None: """ Set string encoding which depends on file version """ @@ -1079,10 +1055,10 @@ def _set_encoding(self): else: self._encoding = "utf-8" - def _read_header(self): + def _read_header(self) -> None: first_char = self.path_or_buf.read(1) if struct.unpack("c", first_char)[0] == b"<": - self._read_new_header(first_char) + self._read_new_header() else: self._read_old_header(first_char) @@ -1091,7 +1067,7 @@ def _read_header(self): # calculate size of a data record self.col_sizes = [self._calcsize(typ) for typ in self.typlist] - def _read_new_header(self, first_char): + def _read_new_header(self) -> None: # The first part of the header is common to 117 - 119. self.path_or_buf.read(27) # stata_dta>
self.format_version = int(self.path_or_buf.read(3)) @@ -1168,15 +1144,17 @@ def _read_new_header(self, first_char): self._variable_labels = self._get_variable_labels() # Get data type information, works for versions 117-119. - def _get_dtypes(self, seek_vartypes): + def _get_dtypes( + self, seek_vartypes: int + ) -> Tuple[List[Union[int, str]], List[Union[int, np.dtype]]]: self.path_or_buf.seek(seek_vartypes) raw_typlist = [ struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] - for i in range(self.nvar) + for _ in range(self.nvar) ] - def f(typ): + def f(typ: int) -> Union[int, str]: if typ <= 2045: return typ try: @@ -1186,7 +1164,7 @@ def f(typ): typlist = [f(x) for x in raw_typlist] - def f(typ): + def g(typ: int) -> Union[str, np.dtype]: if typ <= 2045: return str(typ) try: @@ -1194,20 +1172,17 @@ def f(typ): except KeyError: raise ValueError(f"cannot convert stata dtype [{typ}]") - dtyplist = [f(x) for x in raw_typlist] + dtyplist = [g(x) for x in raw_typlist] return typlist, dtyplist - def _get_varlist(self): - if self.format_version == 117: - b = 33 - elif self.format_version >= 118: - b = 129 - - return [self._decode(self.path_or_buf.read(b)) for i in range(self.nvar)] + def _get_varlist(self) -> List[str]: + # 33 in order formats, 129 in formats 118 and 119 + b = 33 if self.format_version < 118 else 129 + return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] # Returns the format list - def _get_fmtlist(self): + def _get_fmtlist(self) -> List[str]: if self.format_version >= 118: b = 57 elif self.format_version > 113: @@ -1217,40 +1192,40 @@ def _get_fmtlist(self): else: b = 7 - return [self._decode(self.path_or_buf.read(b)) for i in range(self.nvar)] + return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] # Returns the label list - def _get_lbllist(self): + def _get_lbllist(self) -> List[str]: if self.format_version >= 118: b = 129 elif self.format_version > 108: b = 33 else: b = 9 - return [self._decode(self.path_or_buf.read(b)) for i in range(self.nvar)] + return [self._decode(self.path_or_buf.read(b)) for _ in range(self.nvar)] - def _get_variable_labels(self): + def _get_variable_labels(self) -> List[str]: if self.format_version >= 118: vlblist = [ - self._decode(self.path_or_buf.read(321)) for i in range(self.nvar) + self._decode(self.path_or_buf.read(321)) for _ in range(self.nvar) ] elif self.format_version > 105: vlblist = [ - self._decode(self.path_or_buf.read(81)) for i in range(self.nvar) + self._decode(self.path_or_buf.read(81)) for _ in range(self.nvar) ] else: vlblist = [ - self._decode(self.path_or_buf.read(32)) for i in range(self.nvar) + self._decode(self.path_or_buf.read(32)) for _ in range(self.nvar) ] return vlblist - def _get_nobs(self): + def _get_nobs(self) -> int: if self.format_version >= 118: return struct.unpack(self.byteorder + "Q", self.path_or_buf.read(8))[0] else: return struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] - def _get_data_label(self): + def _get_data_label(self) -> str: if self.format_version >= 118: strlen = struct.unpack(self.byteorder + "H", self.path_or_buf.read(2))[0] return self._decode(self.path_or_buf.read(strlen)) @@ -1262,7 +1237,7 @@ def _get_data_label(self): else: return self._decode(self.path_or_buf.read(32)) - def _get_time_stamp(self): + def _get_time_stamp(self) -> str: if self.format_version >= 118: strlen = struct.unpack("b", self.path_or_buf.read(1))[0] return self.path_or_buf.read(strlen).decode("utf-8") @@ -1274,9 +1249,9 @@ def 
_get_time_stamp(self): else: raise ValueError() - def _get_seek_variable_labels(self): + def _get_seek_variable_labels(self) -> int: if self.format_version == 117: - self.path_or_buf.read(8) # <variable_labels>, throw away # Stata 117 data files do not follow the described format. This is # a work around that uses the previous label, 33 bytes for each # variable, 20 for the closing tag and 17 for the opening tag + self.path_or_buf.read(8) # <variable_labels>, throw away @@ -1286,7 +1261,7 @@ def _get_seek_variable_labels(self): else: raise ValueError() - def _read_old_header(self, first_char): + def _read_old_header(self, first_char: bytes) -> None: self.format_version = struct.unpack("b", first_char)[0] if self.format_version not in [104, 105, 108, 111, 113, 114, 115]: raise ValueError(_version_error.format(version=self.format_version)) @@ -1306,7 +1281,7 @@ def _read_old_header(self, first_char): # descriptors if self.format_version > 108: - typlist = [ord(self.path_or_buf.read(1)) for i in range(self.nvar)] + typlist = [ord(self.path_or_buf.read(1)) for _ in range(self.nvar)] else: buf = self.path_or_buf.read(self.nvar) typlistb = np.frombuffer(buf, dtype=np.uint8) @@ -1330,11 +1305,11 @@ def _read_old_header(self, first_char): if self.format_version > 108: self.varlist = [ - self._decode(self.path_or_buf.read(33)) for i in range(self.nvar) + self._decode(self.path_or_buf.read(33)) for _ in range(self.nvar) ] else: self.varlist = [ - self._decode(self.path_or_buf.read(9)) for i in range(self.nvar) + self._decode(self.path_or_buf.read(9)) for _ in range(self.nvar) ] self.srtlist = struct.unpack( self.byteorder + ("h" * (self.nvar + 1)), @@ -1372,26 +1347,27 @@ def _read_old_header(self, first_char): # necessary data to continue parsing self.data_location = self.path_or_buf.tell() - def _setup_dtype(self): + def _setup_dtype(self) -> np.dtype: """Map between numpy and Stata dtypes""" if self._dtype is not None: return self._dtype - dtype = [] # Convert struct data types to numpy data type + dtypes = [] # Convert struct data types to numpy data type for i, typ in enumerate(self.typlist): if typ in self.NUMPY_TYPE_MAP: - dtype.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) + dtypes.append(("s" + str(i), self.byteorder + self.NUMPY_TYPE_MAP[typ])) else: - dtype.append(("s" + str(i), "S" + str(typ))) - dtype = np.dtype(dtype) - self._dtype = dtype + dtypes.append(("s" + str(i), "S" + str(typ))) + self._dtype = np.dtype(dtypes) return self._dtype - def _calcsize(self, fmt): - return type(fmt) is int and fmt or struct.calcsize(self.byteorder + fmt) + def _calcsize(self, fmt: Union[int, str]) -> int: + if isinstance(fmt, int): + return fmt + return struct.calcsize(self.byteorder + fmt) - def _decode(self, s): + def _decode(self, s: bytes) -> str: # have bytes not strings, so must decode s = s.partition(b"\0")[0] try: @@ -1408,24 +1384,25 @@ def _decode(self, s): warnings.warn(msg, UnicodeWarning) return s.decode("latin-1") - def _read_value_labels(self): + def _read_value_labels(self) -> None: if self._value_labels_read: # Don't read twice return if self.format_version <= 108: # Value labels are not supported in version 108 and earlier.
self._value_labels_read = True - self.value_label_dict = dict() + self.value_label_dict: Dict[str, Dict[Union[float, int], str]] = {} return if self.format_version >= 117: self.path_or_buf.seek(self.seek_value_labels) else: + assert self._dtype is not None offset = self.nobs * self._dtype.itemsize self.path_or_buf.seek(self.data_location + offset) self._value_labels_read = True - self.value_label_dict = dict() + self.value_label_dict = {} while True: if self.format_version >= 117: @@ -1461,7 +1438,7 @@ def _read_value_labels(self): self.path_or_buf.read(6) # </lbl> self._value_labels_read = True - def _read_strls(self): + def _read_strls(self) -> None: self.path_or_buf.seek(self.seek_strls) # Wrap v_o in a string to allow uint64 values as keys on 32bit OS self.GSO = {"0": ""} @@ -1476,23 +1453,26 @@ def _read_strls(self): # Only tested on little endian file on little endian machine. v_size = 2 if self.format_version == 118 else 3 if self.byteorder == "<": - buf = buf[0:v_size] + buf[4 : 12 - v_size] + buf = buf[0:v_size] + buf[4 : (12 - v_size)] else: # This path may not be correct, impossible to test - buf = buf[0:v_size] + buf[4 + v_size :] + buf = buf[0:v_size] + buf[(4 + v_size) :] v_o = struct.unpack("Q", buf)[0] typ = struct.unpack("B", self.path_or_buf.read(1))[0] length = struct.unpack(self.byteorder + "I", self.path_or_buf.read(4))[0] va = self.path_or_buf.read(length) if typ == 130: - va = va[0:-1].decode(self._encoding) - # Wrap v_o in a string to allow uint64 values as keys on 32bit OS - self.GSO[str(v_o)] = va + decoded_va = va[0:-1].decode(self._encoding) + else: + # Stata says typ 129 can be binary, so use str + decoded_va = str(va) + # Wrap v_o in a string to allow uint64 values as keys on 32bit OS + self.GSO[str(v_o)] = decoded_va - def __next__(self): + def __next__(self) -> DataFrame: return self.read(nrows=self._chunksize or 1) - def get_chunk(self, size=None): + def get_chunk(self, size: Optional[int] = None) -> DataFrame: """ Reads lines from Stata file and returns as dataframe @@ -1512,15 +1492,15 @@ def get_chunk(self, size=None): @Appender(_read_method_doc) def read( self, - nrows=None, - convert_dates=None, - convert_categoricals=None, - index_col=None, - convert_missing=None, - preserve_dtypes=None, - columns=None, - order_categoricals=None, - ): + nrows: Optional[int] = None, + convert_dates: Optional[bool] = None, + convert_categoricals: Optional[bool] = None, + index_col: Optional[str] = None, + convert_missing: Optional[bool] = None, + preserve_dtypes: Optional[bool] = None, + columns: Optional[Sequence[str]] = None, + order_categoricals: Optional[bool] = None, + ) -> DataFrame: # Handle empty file or chunk. If reading incrementally raise # StopIteration. If reading the whole thing return an empty # data frame.
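[Editor's note] With chunksize typed as Optional[int], the incremental path above drives both iteration and get_chunk. A usage sketch of the public chunked interface ("data.dta" and the per-chunk handler are placeholders):

import pandas as pd

with pd.read_stata("data.dta", chunksize=10_000) as reader:
    for chunk in reader:  # __next__ raises StopIteration at end of file
        handle(chunk)     # hypothetical per-chunk processing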
@@ -1554,6 +1534,7 @@ def read( self._read_strls() # Read data + assert self._dtype is not None dtype = self._dtype max_read_len = (self.nobs - self._lines_read) * dtype.itemsize read_len = nrows * dtype.itemsize @@ -1673,7 +1654,7 @@ def any_startswith(x: str) -> bool: return data - def _do_convert_missing(self, data, convert_missing): + def _do_convert_missing(self, data: DataFrame, convert_missing: bool) -> DataFrame: # Check for missing values, and replace if found replacements = {} for i, colname in enumerate(data): @@ -1689,7 +1670,6 @@ def _do_convert_missing(self, data, convert_missing): continue if convert_missing: # Replacement follows Stata notation - missing_loc = np.argwhere(missing._ndarray_values) umissing, umissing_loc = np.unique(series[missing], return_inverse=True) replacement = Series(series, dtype=np.object) @@ -1707,12 +1687,12 @@ def _do_convert_missing(self, data, convert_missing): replacements[colname] = replacement if replacements: columns = data.columns - replacements = DataFrame(replacements) - data = concat([data.drop(replacements.columns, 1), replacements], 1) - data = data[columns] + replacement_df = DataFrame(replacements) + replaced = concat([data.drop(replacement_df.columns, 1), replacement_df], 1) + data = replaced[columns] return data - def _insert_strls(self, data): + def _insert_strls(self, data: DataFrame) -> DataFrame: if not hasattr(self, "GSO") or len(self.GSO) == 0: return data for i, typ in enumerate(self.typlist): @@ -1722,7 +1702,7 @@ def _insert_strls(self, data): data.iloc[:, i] = [self.GSO[str(k)] for k in data.iloc[:, i]] return data - def _do_select_columns(self, data, columns): + def _do_select_columns(self, data: DataFrame, columns: Sequence[str]) -> DataFrame: if not self._column_selector_set: column_set = set(columns) @@ -1755,9 +1735,13 @@ def _do_select_columns(self, data, columns): return data[columns] + @staticmethod def _do_convert_categoricals( - self, data, value_label_dict, lbllist, order_categoricals - ): + data: DataFrame, + value_label_dict: Dict[str, Dict[Union[float, int], str]], + lbllist: Sequence[str], + order_categoricals: bool, + ) -> DataFrame: """ Converts categorical columns to Categorical type. """ @@ -1777,8 +1761,8 @@ def _do_convert_categoricals( cat_data.categories = categories except ValueError: vc = Series(categories).value_counts() - repeats = list(vc.index[vc > 1]) - repeats = "-" * 80 + "\n" + "\n".join(repeats) + repeated_cats = list(vc.index[vc > 1]) + repeats = "-" * 80 + "\n" + "\n".join(repeated_cats) # GH 25772 msg = f""" Value labels for column {col} are not unique. These cannot be converted to @@ -1793,21 +1777,21 @@ def _do_convert_categoricals( """ raise ValueError(msg) # TODO: is the next line needed above in the data(...) method? - cat_data = Series(cat_data, index=data.index) - cat_converted_data.append((col, cat_data)) + cat_series = Series(cat_data, index=data.index) + cat_converted_data.append((col, cat_series)) else: cat_converted_data.append((col, data[col])) data = DataFrame.from_dict(dict(cat_converted_data)) return data @property - def data_label(self): + def data_label(self) -> str: """ Return data label of Stata file. """ return self._data_label - def variable_labels(self): + def variable_labels(self) -> Dict[str, str]: """ Return variable labels as a dict, associating each variable name with corresponding label. 
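[Editor's note] A short sketch of how the metadata accessors typed above are used ("survey.dta" is a placeholder path); value_labels() is populated once the data have been read:

from pandas.io.stata import StataReader

with StataReader("survey.dta") as reader:
    df = reader.read()
    print(reader.data_label)         # dataset label (str)
    print(reader.variable_labels())  # {variable name: label}
    print(reader.value_labels())     # {variable name: {value: label}}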
@@ -1818,7 +1802,7 @@ def variable_labels(self): """ return dict(zip(self.varlist, self._variable_labels)) - def value_labels(self): + def value_labels(self) -> Dict[str, Dict[Union[float, int], str]]: """ Return a dict, associating each variable name with a dict that associates each value with its corresponding label. @@ -1833,7 +1817,43 @@ def value_labels(self): return self.value_label_dict -def _open_file_binary_write(fname): +@Appender(_read_stata_doc) +def read_stata( + filepath_or_buffer: FilePathOrBuffer, + convert_dates: bool = True, + convert_categoricals: bool = True, + index_col: Optional[str] = None, + convert_missing: bool = False, + preserve_dtypes: bool = True, + columns: Optional[Sequence[str]] = None, + order_categoricals: bool = True, + chunksize: Optional[int] = None, + iterator: bool = False, +) -> Union[DataFrame, StataReader]: + + reader = StataReader( + filepath_or_buffer, + convert_dates=convert_dates, + convert_categoricals=convert_categoricals, + index_col=index_col, + convert_missing=convert_missing, + preserve_dtypes=preserve_dtypes, + columns=columns, + order_categoricals=order_categoricals, + chunksize=chunksize, + ) + + if iterator or chunksize: + return reader + + try: + data = reader.read() + finally: + reader.close() + return data + + +def _open_file_binary_write(fname: FilePathOrBuffer) -> Tuple[BinaryIO, bool]: """ Open a binary file or no-op if file-like. @@ -1849,12 +1869,15 @@ def _open_file_binary_write(fname): True if the file was created, otherwise False """ if hasattr(fname, "write"): - # if 'b' not in fname.mode: - return fname, False - return open(fname, "wb"), True + # See https://github.com/python/mypy/issues/1424 for hasattr challenges + return fname, False # type: ignore + elif isinstance(fname, (str, Path)): + return open(fname, "wb"), True + else: + raise TypeError("fname must be a binary file, buffer or path-like.") -def _set_endianness(endianness): +def _set_endianness(endianness: str) -> str: if endianness.lower() in ["<", "little"]: return "<" elif endianness.lower() in [">", "big"]: @@ -1863,7 +1886,7 @@ def _set_endianness(endianness): raise ValueError(f"Endianness {endianness} not understood") -def _pad_bytes(name, length): +def _pad_bytes(name: AnyStr, length: int) -> AnyStr: """ Take a char string and pad it with null bytes until it is length chars. """ @@ -1872,7 +1895,7 @@ def _pad_bytes(name, length): return name + "\x00" * (length - len(name)) -def _convert_datetime_to_stata_type(fmt): +def _convert_datetime_to_stata_type(fmt: str) -> np.dtype: """ Convert from one of the stata date formats to a type in TYPE_MAP. """ @@ -1897,7 +1920,7 @@ def _convert_datetime_to_stata_type(fmt): raise NotImplementedError(f"Format {fmt} not implemented") -def _maybe_convert_to_int_keys(convert_dates, varlist): +def _maybe_convert_to_int_keys(convert_dates: Dict, varlist: List[Label]) -> Dict: new_dict = {} for key in convert_dates: if not convert_dates[key].startswith("%"): # make sure proper fmts @@ -1911,7 +1934,7 @@ def _maybe_convert_to_int_keys(convert_dates, varlist): return new_dict -def _dtype_to_stata_type(dtype, column): +def _dtype_to_stata_type(dtype: np.dtype, column: Series) -> int: """ Convert dtype types to stata types. Returns the byte of the given ordinal. See TYPE_MAP and comments for an explanation.
This is also explained in @@ -1947,7 +1970,9 @@ def _dtype_to_stata_type(dtype, column): raise NotImplementedError(f"Data type {dtype} not supported.") -def _dtype_to_default_stata_fmt(dtype, column, dta_version=114, force_strl=False): +def _dtype_to_default_stata_fmt( + dtype, column: Series, dta_version: int = 114, force_strl: bool = False +) -> str: """ Map numpy dtype to stata's default format for this type. Not terribly important since users can change this in Stata. Semantics are @@ -2060,14 +2085,14 @@ class StataWriter(StataParser): def __init__( self, - fname, - data, - convert_dates=None, - write_index=True, - byteorder=None, - time_stamp=None, - data_label=None, - variable_labels=None, + fname: FilePathOrBuffer, + data: DataFrame, + convert_dates: Optional[Dict[Label, str]] = None, + write_index: bool = True, + byteorder: Optional[str] = None, + time_stamp: Optional[datetime.datetime] = None, + data_label: Optional[str] = None, + variable_labels: Optional[Dict[Label, str]] = None, ): super().__init__() self._convert_dates = {} if convert_dates is None else convert_dates @@ -2084,21 +2109,30 @@ def __init__( self._byteorder = _set_endianness(byteorder) self._fname = stringify_path(fname) self.type_converters = {253: np.int32, 252: np.int16, 251: np.int8} - self._converted_names = {} + self._converted_names: Dict[Label, str] = {} + self._file: Optional[BinaryIO] = None - def _write(self, to_write): + def _write(self, to_write: str) -> None: """ Helper to call encode before writing to file for Python 3 compat. """ - self._file.write(to_write.encode(self._encoding or self._default_encoding)) + assert self._file is not None + self._file.write(to_write.encode(self._encoding)) - def _prepare_categoricals(self, data): + def _write_bytes(self, value: bytes) -> None: + """ + Helper to assert file is open before writing. + """ + assert self._file is not None + self._file.write(value) + + def _prepare_categoricals(self, data: DataFrame) -> DataFrame: """Check for categorical columns, retain categorical information for Stata file and convert categorical data to int""" is_cat = [is_categorical_dtype(data[col]) for col in data] self._is_col_cat = is_cat - self._value_labels = [] + self._value_labels: List[StataValueLabel] = [] if not any(is_cat): return data @@ -2133,7 +2167,7 @@ def _prepare_categoricals(self, data): data_formatted.append((col, data[col])) return DataFrame.from_dict(dict(data_formatted)) - def _replace_nans(self, data): + def _replace_nans(self, data: DataFrame) -> DataFrame: # return data """Checks floating point data columns for nans, and replaces these with the generic Stata for missing value (.)""" @@ -2148,11 +2182,11 @@ def _replace_nans(self, data): return data - def _update_strl_names(self): + def _update_strl_names(self) -> None: """No-op, forward compatibility""" pass - def _validate_variable_name(self, name): + def _validate_variable_name(self, name: str) -> str: """ Validate variable names for Stata export. @@ -2182,7 +2216,7 @@ def _validate_variable_name(self, name): name = name.replace(c, "_") return name - def _check_column_names(self, data): + def _check_column_names(self, data: DataFrame) -> DataFrame: """ Checks column names to ensure that they are valid Stata column names. 
This includes checks for: @@ -2212,7 +2246,7 @@ def _check_column_names(self, data): name = "_" + name # Variable name may not start with a number - if name[0] >= "0" and name[0] <= "9": + if "0" <= name[0] <= "9": name = "_" + name name = name[: min(len(name), 32)] @@ -2256,21 +2290,23 @@ def _check_column_names(self, data): return data - def _set_formats_and_types(self, dtypes): - self.typlist = [] - self.fmtlist = [] + def _set_formats_and_types(self, dtypes: Series) -> None: + self.fmtlist: List[str] = [] + self.typlist: List[int] = [] for col, dtype in dtypes.items(): self.fmtlist.append(_dtype_to_default_stata_fmt(dtype, self.data[col])) self.typlist.append(_dtype_to_stata_type(dtype, self.data[col])) - def _prepare_pandas(self, data): + def _prepare_pandas(self, data: DataFrame) -> None: # NOTE: we might need a different API / class for pandas objects so # we can set different semantics - handle this with a PR to pandas.io data = data.copy() if self._write_index: - data = data.reset_index() + temp = data.reset_index() + if isinstance(temp, DataFrame): + data = temp # Ensure column names are strings data = self._check_column_names(data) @@ -2313,9 +2349,10 @@ def _prepare_pandas(self, data): # set the given format for the datetime cols if self._convert_dates is not None: for key in self._convert_dates: - self.fmtlist[key] = self._convert_dates[key] + if isinstance(key, int): + self.fmtlist[key] = self._convert_dates[key] - def _encode_strings(self): + def _encode_strings(self) -> None: """ Encode strings in dta-specific encoding @@ -2334,7 +2371,7 @@ def _encode_strings(self): dtype = column.dtype if dtype.type == np.object_: inferred_dtype = infer_dtype(column, skipna=True) - if not ((inferred_dtype in ("string")) or len(column) == 0): + if not ((inferred_dtype == "string") or len(column) == 0): col = column.name raise ValueError( f"""\ @@ -2352,7 +2389,7 @@ def _encode_strings(self): ): self.data[col] = encoded - def write_file(self): + def write_file(self) -> None: self._file, self._own_file = _open_file_binary_write(self._fname) try: self._write_header(data_label=self._data_label, time_stamp=self._time_stamp) @@ -2365,8 +2402,8 @@ def write_file(self): self._write_variable_labels() self._write_expansion_fields() self._write_characteristics() - self._prepare_data() - self._write_data() + records = self._prepare_data() + self._write_data(records) self._write_strls() self._write_value_labels() self._write_file_close_tag() @@ -2375,7 +2412,8 @@ def write_file(self): self._close() if self._own_file: try: - os.unlink(self._fname) + if isinstance(self._fname, (str, Path)): + os.unlink(self._fname) except OSError: warnings.warn( f"This save was not successful but {self._fname} could not " @@ -2386,7 +2424,7 @@ def write_file(self): else: self._close() - def _close(self): + def _close(self) -> None: """ Close the file if it was created by the writer. 
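[Editor's note] The write_file path above is exercised through DataFrame.to_stata; a usage sketch (file name and labels are examples only):

import pandas as pd

df = pd.DataFrame({"x": [1.0, 2.0], "when": pd.to_datetime(["2020-01-01", "2020-01-02"])})
df.to_stata(
    "example.dta",
    convert_dates={"when": "td"},            # %td: days since 1960-01-01
    variable_labels={"x": "A measurement"},  # at most 80 characters each
    write_index=False,
)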
@@ -2396,6 +2434,7 @@ def _close(self): (if supported) """ # Some file-like objects might not support flush + assert self._file is not None try: self._file.flush() except AttributeError: @@ -2403,34 +2442,38 @@ def _close(self): if self._own_file: self._file.close() - def _write_map(self): + def _write_map(self) -> None: """No-op, future compatibility""" pass - def _write_file_close_tag(self): + def _write_file_close_tag(self) -> None: """No-op, future compatibility""" pass - def _write_characteristics(self): + def _write_characteristics(self) -> None: """No-op, future compatibility""" pass - def _write_strls(self): + def _write_strls(self) -> None: """No-op, future compatibility""" pass - def _write_expansion_fields(self): + def _write_expansion_fields(self) -> None: """Write 5 zeros for expansion fields""" self._write(_pad_bytes("", 5)) - def _write_value_labels(self): + def _write_value_labels(self) -> None: for vl in self._value_labels: - self._file.write(vl.generate_value_label(self._byteorder)) + self._write_bytes(vl.generate_value_label(self._byteorder)) - def _write_header(self, data_label=None, time_stamp=None): + def _write_header( + self, + data_label: Optional[str] = None, + time_stamp: Optional[datetime.datetime] = None, + ) -> None: byteorder = self._byteorder # ds_format - just use 114 - self._file.write(struct.pack("b", 114)) + self._write_bytes(struct.pack("b", 114)) # byteorder self._write(byteorder == ">" and "\x01" or "\x02") # filetype @@ -2438,14 +2481,16 @@ def _write_header(self, data_label=None, time_stamp=None): # unused self._write("\x00") # number of vars, 2 bytes - self._file.write(struct.pack(byteorder + "h", self.nvar)[:2]) + self._write_bytes(struct.pack(byteorder + "h", self.nvar)[:2]) # number of obs, 4 bytes - self._file.write(struct.pack(byteorder + "i", self.nobs)[:4]) + self._write_bytes(struct.pack(byteorder + "i", self.nobs)[:4]) # data label 81 bytes, char, null terminated if data_label is None: - self._file.write(self._null_terminate(_pad_bytes("", 80))) + self._write_bytes(self._null_terminate_bytes(_pad_bytes("", 80))) else: - self._file.write(self._null_terminate(_pad_bytes(data_label[:80], 80))) + self._write_bytes( + self._null_terminate_bytes(_pad_bytes(data_label[:80], 80)) + ) # time stamp, 18 bytes, char, null terminated # format dd Mon yyyy hh:mm if time_stamp is None: @@ -2474,43 +2519,43 @@ def _write_header(self, data_label=None, time_stamp=None): + month_lookup[time_stamp.month] + time_stamp.strftime(" %Y %H:%M") ) - self._file.write(self._null_terminate(ts)) + self._write_bytes(self._null_terminate_bytes(ts)) - def _write_variable_types(self): + def _write_variable_types(self) -> None: for typ in self.typlist: - self._file.write(struct.pack("B", typ)) + self._write_bytes(struct.pack("B", typ)) - def _write_varnames(self): + def _write_varnames(self) -> None: # varlist names are checked by _check_column_names # varlist, requires null terminated for name in self.varlist: - name = self._null_terminate(name, True) + name = self._null_terminate_str(name) name = _pad_bytes(name[:32], 33) self._write(name) - def _write_sortlist(self): + def _write_sortlist(self) -> None: # srtlist, 2*(nvar+1), int array, encoded by byteorder srtlist = _pad_bytes("", 2 * (self.nvar + 1)) self._write(srtlist) - def _write_formats(self): + def _write_formats(self) -> None: # fmtlist, 49*nvar, char array for fmt in self.fmtlist: self._write(_pad_bytes(fmt, 49)) - def _write_value_label_names(self): + def _write_value_label_names(self) -> None: # lbllist, 
33*nvar, char array for i in range(self.nvar): # Use variable name when categorical if self._is_col_cat[i]: name = self.varlist[i] - name = self._null_terminate(name, True) + name = self._null_terminate_str(name) name = _pad_bytes(name[:32], 33) self._write(name) else: # Default is empty label self._write(_pad_bytes("", 33)) - def _write_variable_labels(self): + def _write_variable_labels(self) -> None: # Missing labels are 80 blank characters plus null termination blank = _pad_bytes("", 81) @@ -2534,11 +2579,11 @@ def _write_variable_labels(self): else: self._write(blank) - def _convert_strls(self, data): + def _convert_strls(self, data: DataFrame) -> DataFrame: """No-op, future compatibility""" return data - def _prepare_data(self): + def _prepare_data(self) -> np.recarray: data = self.data typlist = self.typlist convert_dates = self._convert_dates @@ -2568,23 +2613,21 @@ def _prepare_data(self): dtype = dtype.newbyteorder(self._byteorder) dtypes[col] = dtype - self.data = data.to_records(index=False, column_dtypes=dtypes) - - def _write_data(self): - data = self.data - self._file.write(data.tobytes()) + return data.to_records(index=False, column_dtypes=dtypes) + def _write_data(self, records: np.recarray) -> None: + self._write_bytes(records.tobytes()) - def _null_terminate(self, s, as_string=False): - null_byte = "\x00" - s += null_byte - - if not as_string: - s = s.encode(self._encoding) + @staticmethod + def _null_terminate_str(s: str) -> str: + s += "\x00" return s + def _null_terminate_bytes(self, s: str) -> bytes: + return self._null_terminate_str(s).encode(self._encoding) + -def _dtype_to_stata_type_117(dtype, column, force_strl): +def _dtype_to_stata_type_117(dtype: np.dtype, column: Series, force_strl: bool) -> int: """ Converts dtype types to stata types. Returns the byte of the given ordinal. See TYPE_MAP and comments for an explanation. This is also explained in @@ -2626,7 +2669,7 @@ def _dtype_to_stata_type_117(dtype, column, force_strl): raise NotImplementedError(f"Data type {dtype} not supported.") -def _pad_bytes_new(name, length): +def _pad_bytes_new(name: Union[str, bytes], length: int) -> bytes: """ Takes a bytes instance and pads it with null bytes until it is length chars. """ @@ -2646,7 +2689,7 @@ class StataStrLWriter: ---------- df : DataFrame DataFrame to convert - columns : list + columns : Sequence[str] List of column names to convert to StrL version : int, optional dta version. Currently supports 117, 118 and 119 @@ -2664,7 +2707,13 @@ class StataStrLWriter: characters. """ - def __init__(self, df, columns, version=117, byteorder=None): + def __init__( + self, + df: DataFrame, + columns: Sequence[str], + version: int = 117, + byteorder: Optional[str] = None, + ): if version not in (117, 118, 119): raise ValueError("Only dta versions 117, 118 and 119 supported") self._dta_ver = version @@ -2691,11 +2740,11 @@ def __init__(self, df, columns, version=117, byteorder=None): self._gso_o_type = gso_o_type self._gso_v_type = gso_v_type - def _convert_key(self, key): + def _convert_key(self, key: Tuple[int, int]) -> int: v, o = key return v + self._o_offet * o - def generate_table(self): + def generate_table(self) -> Tuple[Dict[str, Tuple[int, int]], DataFrame]: """ Generates the GSO lookup table for the DataFrame @@ -2747,7 +2796,7 @@ def generate_table(self): return gso_table, gso_df - def generate_blob(self, gso_table): + def generate_blob(self, gso_table: Dict[str, Tuple[int, int]]) -> bytes: """ Generates the binary blob of GSOs that is written to the dta file.
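[Editor's note] _convert_key above folds a (v, o) pair into the single integer Stata uses to reference a GSO. A sketch of the round trip, assuming an offset of 2**32 purely for illustration (the writer derives the real offset, self._o_offet, from the dta version):

def convert_key(v: int, o: int, o_offset: int = 2 ** 32) -> int:
    # v: variable index, o: observation index
    return v + o_offset * o

key = convert_key(3, 7)
assert (key % 2 ** 32, key // 2 ** 32) == (3, 7)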
@@ -2890,18 +2939,20 @@ class StataWriter117(StataWriter): def __init__( self, - fname, - data, - convert_dates=None, - write_index=True, - byteorder=None, - time_stamp=None, - data_label=None, - variable_labels=None, - convert_strl=None, + fname: FilePathOrBuffer, + data: DataFrame, + convert_dates: Optional[Dict[Label, str]] = None, + write_index: bool = True, + byteorder: Optional[str] = None, + time_stamp: Optional[datetime.datetime] = None, + data_label: Optional[str] = None, + variable_labels: Optional[Dict[Label, str]] = None, + convert_strl: Optional[Sequence[Label]] = None, ): - # Shallow copy since convert_strl might be modified later - self._convert_strl = [] if convert_strl is None else convert_strl[:] + # Copy to new list since convert_strl might be modified later + self._convert_strl: List[Label] = [] + if convert_strl is not None: + self._convert_strl.extend(convert_strl) super().__init__( fname, @@ -2913,24 +2964,29 @@ def __init__( data_label=data_label, variable_labels=variable_labels, ) - self._map = None - self._strl_blob = None + self._map: Dict[str, int] = {} + self._strl_blob = b"" @staticmethod - def _tag(val, tag): + def _tag(val: Union[str, bytes], tag: str) -> bytes: """Surround val with <tag></tag>""" if isinstance(val, str): val = bytes(val, "utf-8") return bytes("<" + tag + ">", "utf-8") + val + bytes("</" + tag + ">", "utf-8") - def _update_map(self, tag): + def _update_map(self, tag: str) -> None: """Update map location for tag with file position""" + assert self._file is not None self._map[tag] = self._file.tell() - def _write_header(self, data_label=None, time_stamp=None): + def _write_header( + self, + data_label: Optional[str] = None, + time_stamp: Optional[datetime.datetime] = None, + ) -> None: """Write the file header""" byteorder = self._byteorder - self._file.write(bytes("<stata_dta>", "utf-8")) + self._write_bytes(bytes("<stata_dta>", "utf-8")) bio = BytesIO() # ds_format - 117 bio.write(self._tag(bytes(str(self._dta_version), "utf-8"), "release")) @@ -2944,11 +3000,11 @@ def _write_header(self, data_label=None, time_stamp=None): bio.write(self._tag(struct.pack(byteorder + nobs_size, self.nobs), "N")) # data label 81 bytes, char, null terminated label = data_label[:80] if data_label is not None else "" - label = label.encode(self._encoding) + encoded_label = label.encode(self._encoding) label_size = "B" if self._dta_version == 117 else "H" - label_len = struct.pack(byteorder + label_size, len(label)) - label = label_len + label - bio.write(self._tag(label, "label")) + label_len = struct.pack(byteorder + label_size, len(encoded_label)) + encoded_label = label_len + encoded_label + bio.write(self._tag(encoded_label, "label")) # time stamp, 18 bytes, char, null terminated # format dd Mon yyyy hh:mm if time_stamp is None: @@ -2977,16 +3033,17 @@ def _write_header(self, data_label=None, time_stamp=None): + time_stamp.strftime(" %Y %H:%M") ) # '\x11' added due to inspection of Stata file - ts = b"\x11" + bytes(ts, "utf-8") - bio.write(self._tag(ts, "timestamp")) + stata_ts = b"\x11" + bytes(ts, "utf-8") + bio.write(self._tag(stata_ts, "timestamp")) bio.seek(0) - self._file.write(self._tag(bio.read(), "header")) + self._write_bytes(self._tag(bio.read(), "header")) - def _write_map(self): + def _write_map(self) -> None: """Called twice during file write. The first populates the values in the map with 0s.
The second call writes the final map locations when all blocks have been written.""" - if self._map is None: + assert self._file is not None + if not self._map: self._map = dict( ( ("stata_data", 0), @@ -3011,43 +3068,43 @@ def _write_map(self): for val in self._map.values(): bio.write(struct.pack(self._byteorder + "Q", val)) bio.seek(0) - self._file.write(self._tag(bio.read(), "map")) + self._write_bytes(self._tag(bio.read(), "map")) - def _write_variable_types(self): + def _write_variable_types(self) -> None: self._update_map("variable_types") bio = BytesIO() for typ in self.typlist: bio.write(struct.pack(self._byteorder + "H", typ)) bio.seek(0) - self._file.write(self._tag(bio.read(), "variable_types")) + self._write_bytes(self._tag(bio.read(), "variable_types")) - def _write_varnames(self): + def _write_varnames(self) -> None: self._update_map("varnames") bio = BytesIO() # 118 scales by 4 to accommodate utf-8 data worst case encoding vn_len = 32 if self._dta_version == 117 else 128 for name in self.varlist: - name = self._null_terminate(name, True) + name = self._null_terminate_str(name) name = _pad_bytes_new(name[:32].encode(self._encoding), vn_len + 1) bio.write(name) bio.seek(0) - self._file.write(self._tag(bio.read(), "varnames")) + self._write_bytes(self._tag(bio.read(), "varnames")) - def _write_sortlist(self): + def _write_sortlist(self) -> None: self._update_map("sortlist") sort_size = 2 if self._dta_version < 119 else 4 - self._file.write(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist")) + self._write_bytes(self._tag(b"\x00" * sort_size * (self.nvar + 1), "sortlist")) - def _write_formats(self): + def _write_formats(self) -> None: self._update_map("formats") bio = BytesIO() fmt_len = 49 if self._dta_version == 117 else 57 for fmt in self.fmtlist: bio.write(_pad_bytes_new(fmt.encode(self._encoding), fmt_len)) bio.seek(0) - self._file.write(self._tag(bio.read(), "formats")) + self._write_bytes(self._tag(bio.read(), "formats")) - def _write_value_label_names(self): + def _write_value_label_names(self) -> None: self._update_map("value_label_names") bio = BytesIO() # 118 scales by 4 to accommodate utf-8 data worst case encoding @@ -3057,13 +3114,13 @@ def _write_value_label_names(self): name = "" # default name if self._is_col_cat[i]: name = self.varlist[i] - name = self._null_terminate(name, True) - name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1) - bio.write(name) + name = self._null_terminate_str(name) + encoded_name = _pad_bytes_new(name[:32].encode(self._encoding), vl_len + 1) + bio.write(encoded_name) bio.seek(0) - self._file.write(self._tag(bio.read(), "value_label_names")) + self._write_bytes(self._tag(bio.read(), "value_label_names")) - def _write_variable_labels(self): + def _write_variable_labels(self) -> None: # Missing labels are 80 blank characters plus null termination self._update_map("variable_labels") bio = BytesIO() @@ -3075,7 +3132,7 @@ def _write_variable_labels(self): for _ in range(self.nvar): bio.write(blank) bio.seek(0) - self._file.write(self._tag(bio.read(), "variable_labels")) + self._write_bytes(self._tag(bio.read(), "variable_labels")) return for col in self.data: @@ -3095,31 +3152,27 @@ def _write_variable_labels(self): else: bio.write(blank) bio.seek(0) - self._file.write(self._tag(bio.read(), "variable_labels")) + self._write_bytes(self._tag(bio.read(), "variable_labels")) - def _write_characteristics(self): + def _write_characteristics(self) -> None: self._update_map("characteristics") - 
self._file.write(self._tag(b"", "characteristics")) + self._write_bytes(self._tag(b"", "characteristics")) - def _write_data(self): + def _write_data(self, records) -> None: self._update_map("data") - data = self.data - self._file.write(b"") - self._file.write(data.tobytes()) - self._file.write(b"") + self._write_bytes(b"") + self._write_bytes(records.tobytes()) + self._write_bytes(b"") - def _write_strls(self): + def _write_strls(self) -> None: self._update_map("strls") - strls = b"" - if self._strl_blob is not None: - strls = self._strl_blob - self._file.write(self._tag(strls, "strls")) + self._write_bytes(self._tag(self._strl_blob, "strls")) - def _write_expansion_fields(self): + def _write_expansion_fields(self) -> None: """No-op in dta 117+""" pass - def _write_value_labels(self): + def _write_value_labels(self) -> None: self._update_map("value_labels") bio = BytesIO() for vl in self._value_labels: @@ -3127,14 +3180,14 @@ def _write_value_labels(self): lab = self._tag(lab, "lbl") bio.write(lab) bio.seek(0) - self._file.write(self._tag(bio.read(), "value_labels")) + self._write_bytes(self._tag(bio.read(), "value_labels")) - def _write_file_close_tag(self): + def _write_file_close_tag(self) -> None: self._update_map("stata_data_close") - self._file.write(bytes("", "utf-8")) + self._write_bytes(bytes("", "utf-8")) self._update_map("end-of-file") - def _update_strl_names(self): + def _update_strl_names(self) -> None: """Update column names for conversion to strl if they might have been changed to comply with Stata naming rules""" # Update convert_strl if names changed @@ -3143,7 +3196,7 @@ def _update_strl_names(self): idx = self._convert_strl.index(orig) self._convert_strl[idx] = new - def _convert_strls(self, data): + def _convert_strls(self, data: DataFrame) -> DataFrame: """Convert columns to StrLs if either very large or in the convert_strl variable""" convert_cols = [ @@ -3159,7 +3212,7 @@ def _convert_strls(self, data): self._strl_blob = ssw.generate_blob(tab) return data - def _set_formats_and_types(self, dtypes): + def _set_formats_and_types(self, dtypes: Series) -> None: self.typlist = [] self.fmtlist = [] for col, dtype in dtypes.items(): @@ -3266,13 +3319,13 @@ def __init__( self, fname: FilePathOrBuffer, data: DataFrame, - convert_dates: Optional[Dict[Hashable, str]] = None, + convert_dates: Optional[Dict[Label, str]] = None, write_index: bool = True, byteorder: Optional[str] = None, time_stamp: Optional[datetime.datetime] = None, data_label: Optional[str] = None, - variable_labels: Optional[Dict[Hashable, str]] = None, - convert_strl: Optional[Sequence[Hashable]] = None, + variable_labels: Optional[Dict[Label, str]] = None, + convert_strl: Optional[Sequence[Label]] = None, version: Optional[int] = None, ): if version is None: