Skip to content

Commit

Permalink
BUG: Properly validate and parse nrows in read_csv
Browse files Browse the repository at this point in the history
  • Loading branch information
gfyoung committed May 25, 2016
1 parent e0a2e3b commit d856051
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 8 deletions.
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.18.2.txt
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,4 @@ Bug Fixes
- Bug in ``groupby`` where ``apply`` returns different result depending on whether first result is ``None`` or not (:issue:`12824`)

- Bug in ``Categorical.remove_unused_categories()`` changes ``.codes`` dtype to platform int (:issue:`13261`)
- Bug in ``pd.read_csv`` in which the ``nrows`` argument was not properly validated for both engines (:issue:`10476`)
24 changes: 22 additions & 2 deletions pandas/io/parsers.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,26 @@
""" % (_parser_params % (_fwf_widths, ''))


def _validate_nrows(nrows):
"""
Checks whether the 'nrows' parameter for parsing is either
an integer OR float that can SAFELY be cast to an integer
without losing accuracy. Raises a ValueError if that is
not the case.
"""
msg = "'nrows' must be an integer"

if nrows is not None:
if com.is_float(nrows):
if int(nrows) != nrows:
raise ValueError(msg)
nrows = int(nrows)
elif not com.is_integer(nrows):
raise ValueError(msg)

return nrows


def _read(filepath_or_buffer, kwds):
"Generic reader of line files."
encoding = kwds.get('encoding', None)
Expand Down Expand Up @@ -311,14 +331,14 @@ def _read(filepath_or_buffer, kwds):

# Extract some of the arguments (pass chunksize on).
iterator = kwds.get('iterator', False)
nrows = kwds.pop('nrows', None)
chunksize = kwds.get('chunksize', None)
nrows = _validate_nrows(kwds.pop('nrows', None))

# Create the parser.
parser = TextFileReader(filepath_or_buffer, **kwds)

if (nrows is not None) and (chunksize is not None):
raise NotImplementedError("'nrows' and 'chunksize' can not be used"
raise NotImplementedError("'nrows' and 'chunksize' cannot be used"
" together yet.")
elif nrows is not None:
return parser.read(nrows)
Expand Down
20 changes: 14 additions & 6 deletions pandas/io/tests/parser/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,10 +391,23 @@ def test_int_conversion(self):
self.assertEqual(data['B'].dtype, np.int64)

def test_read_nrows(self):
    # Reference result: parse the whole buffer, then keep the
    # first three rows.
    expected = self.read_csv(StringIO(self.data1))[:3]

    # An integer 'nrows' limits the number of rows parsed.
    df = self.read_csv(StringIO(self.data1), nrows=3)
    tm.assert_frame_equal(df, expected)

    # see gh-10476
    # A float holding an integral value is accepted as well.
    df = self.read_csv(StringIO(self.data1), nrows=3.0)
    tm.assert_frame_equal(df, expected)

    msg = "must be an integer"

    # A float with a fractional part is rejected ...
    with tm.assertRaisesRegexp(ValueError, msg):
        self.read_csv(StringIO(self.data1), nrows=1.2)

    # ... and so is a non-numeric value.
    with tm.assertRaisesRegexp(ValueError, msg):
        self.read_csv(StringIO(self.data1), nrows='foo')

def test_read_chunksize(self):
reader = self.read_csv(StringIO(self.data1), index_col=0, chunksize=2)
df = self.read_csv(StringIO(self.data1), index_col=0)
Expand Down Expand Up @@ -815,11 +828,6 @@ def test_ignore_leading_whitespace(self):
expected = DataFrame({'a': [1, 4, 7], 'b': [2, 5, 8], 'c': [3, 6, 9]})
tm.assert_frame_equal(result, expected)

def test_nrows_and_chunksize_raises_notimplemented(self):
    # Supplying both 'nrows' and 'chunksize' must raise
    # NotImplementedError.
    buf = StringIO('a b c')
    with self.assertRaises(NotImplementedError):
        self.read_csv(buf, nrows=10, chunksize=5)

def test_chunk_begins_with_newline_whitespace(self):
# see gh-10022
data = '\n hello\nworld\n'
Expand Down
9 changes: 9 additions & 0 deletions pandas/io/tests/parser/test_unsupported.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,15 @@ def test_mangle_dupe_cols_false(self):
read_csv(StringIO(data), engine=engine,
mangle_dupe_cols=False)

def test_nrows_and_chunksize(self):
    # Both engines must reject a call that passes 'nrows' together
    # with 'chunksize', raising NotImplementedError with the
    # expected message.
    expected_msg = "cannot be used together yet"
    raw = 'a b c'

    for parser_engine in ('c', 'python'):
        options = dict(engine=parser_engine, nrows=10, chunksize=5)
        with tm.assertRaisesRegexp(NotImplementedError, expected_msg):
            read_csv(StringIO(raw), **options)

def test_c_engine(self):
# see gh-6607
data = 'a b c\n1 2 3'
Expand Down

0 comments on commit d856051

Please sign in to comment.