Skip to content

Commit

Permalink
BUG: GH17778 - DataFrame.to_pickle() fails for .zip format.
Browse files (browse the repository at this point in the history)
GH17778: add 'zip' format to the unit tests.
Added an entry to the Bug Fixes section of doc/source/whatsnew/v0.22.0.txt.
  • Loading branch information
Krzysztof Chomski authored and kchomski committed Nov 14, 2017
1 parent 63e8527 commit 53593d0
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 14 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.22.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Documentation Changes
Bug Fixes
~~~~~~~~~

- Bug in ``DataFrame.to_pickle()`` where writing to the ``.zip`` format failed (:issue:`17778`)

Conversion
^^^^^^^^^^
Expand Down
23 changes: 13 additions & 10 deletions pandas/io/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,17 +357,20 @@ def _get_handle(path_or_buf, mode, encoding=None, compression=None,
# ZIP Compression
elif compression == 'zip':
import zipfile
zip_file = zipfile.ZipFile(path_or_buf)
zip_names = zip_file.namelist()
if len(zip_names) == 1:
f = zip_file.open(zip_names.pop())
elif len(zip_names) == 0:
raise ValueError('Zero files found in ZIP file {}'
.format(path_or_buf))
if mode == 'wb':
f = zipfile.ZipFile(path_or_buf, 'w')
else:
raise ValueError('Multiple files found in ZIP file.'
' Only one file per ZIP: {}'
.format(zip_names))
zip_file = zipfile.ZipFile(path_or_buf)
zip_names = zip_file.namelist()
if len(zip_names) == 1:
f = zip_file.open(zip_names.pop())
elif len(zip_names) == 0:
raise ValueError('Zero files found in ZIP file {}'
.format(path_or_buf))
else:
raise ValueError('Multiple files found in ZIP file.'
' Only one file per ZIP: {}'
.format(zip_names))

# XZ Compression
elif compression == 'xz':
Expand Down
12 changes: 11 additions & 1 deletion pandas/io/pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,17 @@ def to_pickle(obj, path, compression='infer', protocol=pkl.HIGHEST_PROTOCOL):
if protocol < 0:
protocol = pkl.HIGHEST_PROTOCOL
try:
pkl.dump(obj, f, protocol=protocol)
import zipfile
if isinstance(f, zipfile.ZipFile):
import os
import tempfile
tmp_file = tempfile.NamedTemporaryFile(delete=False)
pkl.dump(obj, tmp_file, protocol=protocol)
tmp_file.close()
f.write(tmp_file.name)
os.remove(tmp_file.name)
else:
pkl.dump(obj, f, protocol=protocol)
finally:
for _f in fh:
_f.close()
Expand Down
7 changes: 4 additions & 3 deletions pandas/tests/io/test_pickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,7 +382,7 @@ def decompress_file(self, src_path, dest_path, compression):
fh.write(f.read())
f.close()

@pytest.mark.parametrize('compression', [None, 'gzip', 'bz2', 'xz'])
@pytest.mark.parametrize('compression', [None, 'gzip', 'zip', 'bz2', 'xz'])
def test_write_explicit(self, compression, get_random_path):
# issue 11666
if compression == 'xz':
Expand Down Expand Up @@ -414,7 +414,8 @@ def test_write_explicit_bad(self, compression, get_random_path):
df = tm.makeDataFrame()
df.to_pickle(path, compression=compression)

@pytest.mark.parametrize('ext', ['', '.gz', '.bz2', '.xz', '.no_compress'])
@pytest.mark.parametrize('ext', ['', '.gz', '.zip', '.bz2', '.xz',
'.no_compress'])
def test_write_infer(self, ext, get_random_path):
if ext == '.xz':
tm._skip_if_no_lzma()
Expand Down Expand Up @@ -442,7 +443,7 @@ def test_write_infer(self, ext, get_random_path):

tm.assert_frame_equal(df, df2)

@pytest.mark.parametrize('compression', [None, 'gzip', 'bz2', 'xz', "zip"])
@pytest.mark.parametrize('compression', [None, 'gzip', 'bz2', 'xz', 'zip'])
def test_read_explicit(self, compression, get_random_path):
# issue 11666
if compression == 'xz':
Expand Down

0 comments on commit 53593d0

Please sign in to comment.