diff --git a/doc/source/whatsnew/v0.22.0.txt b/doc/source/whatsnew/v0.22.0.txt index 8afdd1b2e22b3..4211d9913a497 100644 --- a/doc/source/whatsnew/v0.22.0.txt +++ b/doc/source/whatsnew/v0.22.0.txt @@ -90,6 +90,7 @@ Documentation Changes Bug Fixes ~~~~~~~~~ +- Bug in ``DataFrame.to_pickle()`` that caused writing with ``zip`` compression to fail (:issue:`17778`) Conversion ^^^^^^^^^^ diff --git a/pandas/io/common.py b/pandas/io/common.py index 534c1e0671150..f799cab161cd9 100644 --- a/pandas/io/common.py +++ b/pandas/io/common.py @@ -357,17 +357,20 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None, # ZIP Compression elif compression == 'zip': import zipfile - zip_file = zipfile.ZipFile(path_or_buf) - zip_names = zip_file.namelist() - if len(zip_names) == 1: - f = zip_file.open(zip_names.pop()) - elif len(zip_names) == 0: - raise ValueError('Zero files found in ZIP file {}' - .format(path_or_buf)) + if mode == 'wb': + f = zipfile.ZipFile(path_or_buf, 'w') else: - raise ValueError('Multiple files found in ZIP file.' - ' Only one file per ZIP: {}' - .format(zip_names)) + zip_file = zipfile.ZipFile(path_or_buf) + zip_names = zip_file.namelist() + if len(zip_names) == 1: + f = zip_file.open(zip_names.pop()) + elif len(zip_names) == 0: + raise ValueError('Zero files found in ZIP file {}' + .format(path_or_buf)) + else: + raise ValueError('Multiple files found in ZIP file.'
+ ' Only one file per ZIP: {}' + .format(zip_names)) # XZ Compression elif compression == 'xz': diff --git a/pandas/io/pickle.py b/pandas/io/pickle.py index 143b76575e36b..aab9ffa1cce45 100644 --- a/pandas/io/pickle.py +++ b/pandas/io/pickle.py @@ -42,7 +42,17 @@ def to_pickle(obj, path, compression='infer', protocol=pkl.HIGHEST_PROTOCOL): if protocol < 0: protocol = pkl.HIGHEST_PROTOCOL try: - pkl.dump(obj, f, protocol=protocol) + import zipfile + if isinstance(f, zipfile.ZipFile): + import os + import tempfile + tmp_file = tempfile.NamedTemporaryFile(delete=False) + pkl.dump(obj, tmp_file, protocol=protocol) + tmp_file.close() + f.write(tmp_file.name) + os.remove(tmp_file.name) + else: + pkl.dump(obj, f, protocol=protocol) finally: for _f in fh: _f.close() diff --git a/pandas/tests/io/test_pickle.py b/pandas/tests/io/test_pickle.py index 91c1f19f5caab..91b59b2ff3ffb 100644 --- a/pandas/tests/io/test_pickle.py +++ b/pandas/tests/io/test_pickle.py @@ -382,7 +382,7 @@ def decompress_file(self, src_path, dest_path, compression): fh.write(f.read()) f.close() - @pytest.mark.parametrize('compression', [None, 'gzip', 'bz2', 'xz']) + @pytest.mark.parametrize('compression', [None, 'gzip', 'zip', 'bz2', 'xz']) def test_write_explicit(self, compression, get_random_path): # issue 11666 if compression == 'xz': @@ -414,7 +414,8 @@ def test_write_explicit_bad(self, compression, get_random_path): df = tm.makeDataFrame() df.to_pickle(path, compression=compression) - @pytest.mark.parametrize('ext', ['', '.gz', '.bz2', '.xz', '.no_compress']) + @pytest.mark.parametrize('ext', ['', '.gz', '.zip', '.bz2', '.xz', + '.no_compress']) def test_write_infer(self, ext, get_random_path): if ext == '.xz': tm._skip_if_no_lzma() @@ -442,7 +443,7 @@ def test_write_infer(self, ext, get_random_path): tm.assert_frame_equal(df, df2) - @pytest.mark.parametrize('compression', [None, 'gzip', 'bz2', 'xz', "zip"]) + @pytest.mark.parametrize('compression', [None, 'gzip', 'bz2', 'xz', 'zip']) def 
test_read_explicit(self, compression, get_random_path): # issue 11666 if compression == 'xz':