diff --git a/doc/source/v0.15.0.txt b/doc/source/v0.15.0.txt index d608304511a08..6c58e751a6bcc 100644 --- a/doc/source/v0.15.0.txt +++ b/doc/source/v0.15.0.txt @@ -482,6 +482,8 @@ Performance - Performance improvements in ``Period`` creation (and ``PeriodIndex`` setitem) (:issue:`5155`) - Improvements in Series.transform for significant performance gains (revised) (:issue:`6496`) - Performance improvements in ``StataReader`` when reading large files (:issue:`8040`, :issue:`8073`) +- Performance improvements in ``StataWriter`` when writing large files (:issue:`8079`) + diff --git a/pandas/io/stata.py b/pandas/io/stata.py index 0cf57d3035db5..246465153c611 100644 --- a/pandas/io/stata.py +++ b/pandas/io/stata.py @@ -19,11 +19,12 @@ from pandas.core.series import Series from pandas.core.categorical import Categorical import datetime -from pandas import compat, to_timedelta, to_datetime +from pandas import compat, to_timedelta, to_datetime, isnull, DatetimeIndex from pandas.compat import lrange, lmap, lzip, text_type, string_types, range, \ zip +import pandas.core.common as com from pandas.io.common import get_filepath_or_buffer -from pandas.lib import max_len_string_array, is_string_array +from pandas.lib import max_len_string_array, infer_dtype from pandas.tslib import NaT, Timestamp def read_stata(filepath_or_buffer, convert_dates=True, @@ -63,88 +64,6 @@ def read_stata(filepath_or_buffer, convert_dates=True, stata_epoch = datetime.datetime(1960, 1, 1) -def _stata_elapsed_date_to_datetime(date, fmt): - """ - Convert from SIF to datetime. http://www.stata.com/help.cgi?datetime - - Parameters - ---------- - date : int - The Stata Internal Format date to convert to datetime according to fmt - fmt : str - The format to convert to. Can be, tc, td, tw, tm, tq, th, ty - - Examples - -------- - >>> _stata_elapsed_date_to_datetime(52, "%tw") - datetime.datetime(1961, 1, 1, 0, 0) - - Notes - ----- - datetime/c - tc - milliseconds since 01jan1960 00:00:00.000, assuming 86,400 s/day - datetime/C - tC - NOT IMPLEMENTED - milliseconds since 01jan1960 00:00:00.000, adjusted for leap seconds - date - td - days since 01jan1960 (01jan1960 = 0) - weekly date - tw - weeks since 1960w1 - This assumes 52 weeks in a year, then adds 7 * remainder of the weeks. - The datetime value is the start of the week in terms of days in the - year, not ISO calendar weeks. - monthly date - tm - months since 1960m1 - quarterly date - tq - quarters since 1960q1 - half-yearly date - th - half-years since 1960h1 yearly - date - ty - years since 0000 - - If you don't have pandas with datetime support, then you can't do - milliseconds accurately. - """ - #NOTE: we could run into overflow / loss of precision situations here - # casting to int, but I'm not sure what to do. datetime won't deal with - # numpy types and numpy datetime isn't mature enough / we can't rely on - # pandas version > 0.7.1 - #TODO: IIRC relative delta doesn't play well with np.datetime? - #TODO: When pandas supports more than datetime64[ns], this should be improved to use correct range, e.g. datetime[Y] for yearly - if np.isnan(date): - return NaT - date = int(date) - if fmt in ["%tc", "tc"]: - from dateutil.relativedelta import relativedelta - return stata_epoch + relativedelta(microseconds=date * 1000) - elif fmt in ["%tC", "tC"]: - from warnings import warn - warn("Encountered %tC format. Leaving in Stata Internal Format.") - return date - elif fmt in ["%td", "td", "%d", "d"]: - return stata_epoch + datetime.timedelta(int(date)) - elif fmt in ["%tw", "tw"]: # does not count leap days - 7 days is a week - year = datetime.datetime(stata_epoch.year + date // 52, 1, 1) - day_delta = (date % 52) * 7 - return year + datetime.timedelta(int(day_delta)) - elif fmt in ["%tm", "tm"]: - year = stata_epoch.year + date // 12 - month_delta = (date % 12) + 1 - return datetime.datetime(year, month_delta, 1) - elif fmt in ["%tq", "tq"]: - year = stata_epoch.year + date // 4 - month_delta = (date % 4) * 3 + 1 - return datetime.datetime(year, month_delta, 1) - elif fmt in ["%th", "th"]: - year = stata_epoch.year + date // 2 - month_delta = (date % 2) * 6 + 1 - return datetime.datetime(year, month_delta, 1) - elif fmt in ["%ty", "ty"]: - if date > 0: - return datetime.datetime(date, 1, 1) - else: # don't do negative years bc can't mix dtypes in column - raise ValueError("Year 0 and before not implemented") - else: - raise ValueError("Date fmt %s not understood" % fmt) def _stata_elapsed_date_to_datetime_vec(dates, fmt): @@ -153,7 +72,7 @@ def _stata_elapsed_date_to_datetime_vec(dates, fmt): Parameters ---------- - dates : array-like + dates : Series The Stata Internal Format date to convert to datetime according to fmt fmt : str The format to convert to. Can be, tc, td, tw, tm, tq, th, ty @@ -166,8 +85,11 @@ def _stata_elapsed_date_to_datetime_vec(dates, fmt): Examples -------- - >>> _stata_elapsed_date_to_datetime(52, "%tw") - datetime.datetime(1961, 1, 1, 0, 0) + >>> import pandas as pd + >>> dates = pd.Series([52]) + >>> _stata_elapsed_date_to_datetime_vec(dates , "%tw") + 0 1961-01-01 + dtype: datetime64[ns] Notes ----- @@ -288,7 +210,6 @@ def convert_delta_safe(base, deltas, unit): month = (dates % 2) * 6 + 1 conv_dates = convert_year_month_safe(year, month) elif fmt in ["%ty", "ty"]: # Years -- not delta - # TODO: Check about negative years, here, and raise or warn if needed year = dates month = np.ones_like(dates) conv_dates = convert_year_month_safe(year, month) @@ -299,49 +220,103 @@ def convert_delta_safe(base, deltas, unit): conv_dates[bad_locs] = NaT return conv_dates -def _datetime_to_stata_elapsed(date, fmt): + +def _datetime_to_stata_elapsed_vec(dates, fmt): """ Convert from datetime to SIF. http://www.stata.com/help.cgi?datetime Parameters ---------- - date : datetime.datetime - The date to convert to the Stata Internal Format given by fmt + dates : Series + Series or array containing datetime.datetime or datetime64[ns] to + convert to the Stata Internal Format given by fmt fmt : str The format to convert to. Can be, tc, td, tw, tm, tq, th, ty """ - if not isinstance(date, datetime.datetime): - raise ValueError("date should be datetime.datetime format") - stata_epoch = datetime.datetime(1960, 1, 1) - # Handle NaTs - if date is NaT: - # Missing value for dates ('.'), assumed always double - # TODO: Should be moved so a const somewhere, and consolidated - return struct.unpack(' 6) + d = parse_dates_safe(dates, year=True) + conv_dates = 2 * (d.year - stata_epoch.year) + \ + (d.month > 6).astype(np.int) elif fmt in ["%ty", "ty"]: - return date.year + d = parse_dates_safe(dates, year=True) + conv_dates = d.year else: raise ValueError("fmt %s not understood" % fmt) + conv_dates = Series(conv_dates, dtype=np.float64) + missing_value = struct.unpack('= 2 * 53: + if data[col].max() >= 2 ** 53: ws = precision_loss_doc % ('uint64', 'float64') data[col] = data[col].astype(dtype) @@ -1254,9 +1229,8 @@ def _dtype_to_default_stata_fmt(dtype, column): Maps numpy dtype to stata's default format for this type. Not terribly important since users can change this in Stata. Semantics are - string -> "%DDs" where DD is the length of the string - object -> "%DDs" where DD is the length of the string, if a string, or 244 - for anything that cannot be converted to a string. + object -> "%DDs" where DD is the length of the string. If not a string, + raise ValueError float64 -> "%10.0g" float32 -> "%9.0g" int64 -> "%9.0g" @@ -1264,19 +1238,13 @@ def _dtype_to_default_stata_fmt(dtype, column): int16 -> "%8.0g" int8 -> "%8.0g" """ - #TODO: expand this to handle a default datetime format? - if dtype.type == np.string_: - if max_len_string_array(column.values) > 244: - raise ValueError(excessive_string_length_error % column.name) - - return "%" + str(dtype.itemsize) + "s" - elif dtype.type == np.object_: - try: - # Try to use optimal size if available - itemsize = max_len_string_array(column.values) - except: - # Default size - itemsize = 244 + # TODO: expand this to handle a default datetime format? + if dtype.type == np.object_: + inferred_dtype = infer_dtype(column.dropna()) + if not (inferred_dtype in ('string', 'unicode') + or len(column) == 0): + raise ValueError('Writing general object arrays is not supported') + itemsize = max_len_string_array(column.values) if itemsize > 244: raise ValueError(excessive_string_length_error % column.name) @@ -1328,12 +1296,15 @@ class StataWriter(StataParser): Examples -------- + >>> import pandas as pd + >>> data = pd.DataFrame([[1.0, 1]], columns=['a', 'b']) >>> writer = StataWriter('./data_file.dta', data) >>> writer.write_file() Or with dates - - >>> writer = StataWriter('./date_data_file.dta', date, {2 : 'tw'}) + >>> from datetime import datetime + >>> data = pd.DataFrame([[datetime(2000,1,1)]], columns=['date']) + >>> writer = StataWriter('./date_data_file.dta', data, {'date' : 'tw'}) >>> writer.write_file() """ def __init__(self, fname, data, convert_dates=None, write_index=True, @@ -1502,11 +1473,8 @@ def write_file(self): self._write_variable_labels() # write 5 zeros for expansion fields self._write(_pad_bytes("", 5)) - if self._convert_dates is None: - self._write_data_nodates() - else: - self._write_data_dates() - #self._write_value_labels() + self._prepare_data() + self._write_data() self._file.close() def _write_header(self, data_label=None, time_stamp=None): @@ -1573,59 +1541,46 @@ def _write_variable_labels(self, labels=None): for i in range(nvar): self._write(_pad_bytes("", 81)) - def _write_data_nodates(self): - data = self.datarows - byteorder = self._byteorder - TYPE_MAP = self.TYPE_MAP + def _prepare_data(self): + data = self.data.copy() typlist = self.typlist - for row in data: - #row = row.squeeze().tolist() # needed for structured arrays - for i, var in enumerate(row): - typ = ord(typlist[i]) - if typ <= 244: # we've got a string - if var is None or var == np.nan: - var = _pad_bytes('', typ) - if len(var) < typ: - var = _pad_bytes(var, typ) - if compat.PY3: - self._write(var) - else: - self._write(var.encode(self._encoding)) - else: - try: - self._file.write(struct.pack(byteorder + TYPE_MAP[typ], - var)) - except struct.error: - # have to be strict about type pack won't do any - # kind of casting - self._file.write(struct.pack(byteorder+TYPE_MAP[typ], - self.type_converters[typ](var))) - - def _write_data_dates(self): convert_dates = self._convert_dates - data = self.datarows - byteorder = self._byteorder - TYPE_MAP = self.TYPE_MAP - MISSING_VALUES = self.MISSING_VALUES - typlist = self.typlist - for row in data: - #row = row.squeeze().tolist() # needed for structured arrays - for i, var in enumerate(row): - typ = ord(typlist[i]) - #NOTE: If anyone finds this terribly slow, there is - # a vectorized way to convert dates, see genfromdta for going - # from int to datetime and reverse it. will copy data though + # 1. Convert dates + if self._convert_dates is not None: + for i, col in enumerate(data): if i in convert_dates: - var = _datetime_to_stata_elapsed(var, self.fmtlist[i]) - if typ <= 244: # we've got a string - if len(var) < typ: - var = _pad_bytes(var, typ) - if compat.PY3: - self._write(var) - else: - self._write(var.encode(self._encoding)) - else: - self._file.write(struct.pack(byteorder+TYPE_MAP[typ], var)) + data[col] = _datetime_to_stata_elapsed_vec(data[col], + self.fmtlist[i]) + + # 2. Convert bad string data to '' and pad to correct length + dtype = [] + data_cols = [] + has_strings = False + for i, col in enumerate(data): + typ = ord(typlist[i]) + if typ <= 244: + has_strings = True + data[col] = data[col].fillna('').apply(_pad_bytes, args=(typ,)) + stype = 'S%d' % typ + dtype.append(('c'+str(i), stype)) + string = data[col].str.encode(self._encoding) + data_cols.append(string.values.astype(stype)) + else: + dtype.append(('c'+str(i), data[col].dtype)) + data_cols.append(data[col].values) + dtype = np.dtype(dtype) + + # 3. Convert to record array + + # data.to_records(index=False, convert_datetime64=False) + if has_strings: + self.data = np.fromiter(zip(*data_cols), dtype=dtype) + else: + self.data = data.to_records(index=False) + + def _write_data(self): + data = self.data + data.tofile(self._file) def _null_terminate(self, s, as_string=False): null_byte = '\x00' diff --git a/pandas/io/tests/data/stata9_115.dta b/pandas/io/tests/data/stata9_115.dta index 6c3b6ab4dc686..5ad6cd6a2c8ff 100644 Binary files a/pandas/io/tests/data/stata9_115.dta and b/pandas/io/tests/data/stata9_115.dta differ diff --git a/pandas/io/tests/data/stata9_117.dta b/pandas/io/tests/data/stata9_117.dta index 6c3b6ab4dc686..5ad6cd6a2c8ff 100644 Binary files a/pandas/io/tests/data/stata9_117.dta and b/pandas/io/tests/data/stata9_117.dta differ diff --git a/pandas/io/tests/test_stata.py b/pandas/io/tests/test_stata.py index 54c1dd20029ee..c458688b3d2d2 100644 --- a/pandas/io/tests/test_stata.py +++ b/pandas/io/tests/test_stata.py @@ -646,14 +646,14 @@ def test_missing_value_conversion(self): tm.assert_frame_equal(expected, parsed_117) def test_big_dates(self): - yr = [1960, 2000, 9999, 100] - mo = [1, 1, 12, 1] - dd = [1, 1, 31, 1] - hr = [0, 0, 23, 0] - mm = [0, 0, 59, 0] - ss = [0, 0, 59, 0] + yr = [1960, 2000, 9999, 100, 2262, 1677] + mo = [1, 1, 12, 1, 4, 9] + dd = [1, 1, 31, 1, 22, 23] + hr = [0, 0, 23, 0, 0, 0] + mm = [0, 0, 59, 0, 0, 0] + ss = [0, 0, 59, 0, 0, 0] expected = [] - for i in range(4): + for i in range(len(yr)): row = [] for j in range(7): if j == 0: @@ -672,6 +672,11 @@ def test_big_dates(self): expected[2][3] = datetime(9999,12,1) expected[2][4] = datetime(9999,10,1) expected[2][5] = datetime(9999,7,1) + expected[4][2] = datetime(2262,4,16) + expected[4][3] = expected[4][4] = datetime(2262,4,1) + expected[4][5] = expected[4][6] = datetime(2262,1,1) + expected[5][2] = expected[5][3] = expected[5][4] = datetime(1677,10,1) + expected[5][5] = expected[5][6] = datetime(1678,1,1) expected = DataFrame(expected, columns=columns, dtype=np.object) @@ -679,7 +684,17 @@ def test_big_dates(self): parsed_117 = read_stata(self.dta18_117) tm.assert_frame_equal(expected, parsed_115) tm.assert_frame_equal(expected, parsed_117) - assert True + + date_conversion = dict((c, c[-2:]) for c in columns) + #{c : c[-2:] for c in columns} + with tm.ensure_clean() as path: + expected.index.name = 'index' + expected.to_stata(path, date_conversion) + written_and_read_again = self.read_dta(path) + tm.assert_frame_equal(written_and_read_again.set_index('index'), + expected) + + if __name__ == '__main__':