diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py
index 0442020cef3..8bc6d0ea9dc 100644
--- a/python/cudf_polars/cudf_polars/dsl/ir.py
+++ b/python/cudf_polars/cudf_polars/dsl/ir.py
@@ -190,15 +190,14 @@ class Scan(IR):
     """Cloud-related authentication options, currently ignored."""
     paths: list[str]
     """List of paths to read from."""
-    file_options: Any
-    """Options for reading the file.
-
-    Attributes are:
-    - ``with_columns: list[str]`` of projected columns to return.
-    - ``n_rows: int``: Number of rows to read.
-    - ``row_index: tuple[name, offset] | None``: Add an integer index
-      column with given name.
-    """
+    with_columns: list[str]
+    """Projected columns to return."""
+    skip_rows: int
+    """Rows to skip at the start when reading."""
+    n_rows: int
+    """Number of rows to read after skipping."""
+    row_index: tuple[str, int] | None
+    """If not None add an integer index column of the given name."""
     predicate: expr.NamedExpr | None
     """Mask to apply to the read dataframe."""
 
@@ -208,8 +207,16 @@ def __post_init__(self) -> None:
             # This line is unhittable ATM since IPC/Anonymous scan raise
             # on the polars side
             raise NotImplementedError(f"Unhandled scan type: {self.typ}")
-        if self.typ == "ndjson" and self.file_options.n_rows is not None:
-            raise NotImplementedError("row limit in scan")
+        if self.typ == "ndjson" and (self.n_rows != -1 or self.skip_rows != 0):
+            raise NotImplementedError("row limit in scan for json reader")
+        if self.skip_rows < 0:
+            # TODO: polars has this implemented for parquet,
+            # maybe we can do this too?
+            raise NotImplementedError("slice pushdown for negative slices")
+        if self.typ == "csv" and self.skip_rows != 0:  # pragma: no cover
+            # This comes from slice pushdown, but that
+            # optimization doesn't happen right now
+            raise NotImplementedError("skipping rows in CSV reader")
         if self.cloud_options is not None and any(
             self.cloud_options.get(k) is not None for k in ("aws", "azure", "gcp")
         ):
@@ -246,10 +253,9 @@ def __post_init__(self) -> None:
 
     def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
         """Evaluate and return a dataframe."""
-        options = self.file_options
-        with_columns = options.with_columns
-        row_index = options.row_index
-        nrows = self.file_options.n_rows if self.file_options.n_rows is not None else -1
+        with_columns = self.with_columns
+        row_index = self.row_index
+        n_rows = self.n_rows
         if self.typ == "csv":
             parse_options = self.reader_options["parse_options"]
             sep = chr(parse_options["separator"])
@@ -283,6 +289,7 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
 
             # polars skips blank lines at the beginning of the file
             pieces = []
+            read_partial = n_rows != -1
             for p in self.paths:
                 skiprows = self.reader_options["skip_rows"]
                 path = Path(p)
@@ -304,9 +311,13 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
                     comment=comment,
                     decimal=decimal,
                     dtypes=self.schema,
-                    nrows=nrows,
+                    nrows=n_rows,
                 )
                 pieces.append(tbl_w_meta)
+                if read_partial:
+                    n_rows -= tbl_w_meta.tbl.num_rows()
+                    if n_rows <= 0:
+                        break
             tables, colnames = zip(
                 *(
                     (piece.tbl, piece.column_names(include_children=False))
@@ -321,7 +332,8 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
             tbl_w_meta = plc.io.parquet.read_parquet(
                 plc.io.SourceInfo(self.paths),
                 columns=with_columns,
-                num_rows=nrows,
+                num_rows=n_rows,
+                skip_rows=self.skip_rows,
             )
             df = DataFrame.from_table(
                 tbl_w_meta.tbl,
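Note on the CSV path: because `plc.io.csv.read_csv` is called once per path, the row limit has to be threaded through the loop as a running remainder. Below is a standalone sketch of that accounting, with plain Python lists standing in for files and a hypothetical `read_file` in place of the libcudf reader:

```python
# Lists of rows stand in for files; `read_file` is a hypothetical reader that
# honours an nrows limit the way the real reader does (nrows=-1 reads all).
def read_file(rows: list[int], nrows: int) -> list[int]:
    return list(rows) if nrows == -1 else rows[:nrows]


def read_across_files(files: list[list[int]], n_rows: int = -1) -> list[list[int]]:
    read_partial = n_rows != -1  # -1 means "no row limit"
    pieces = []
    for f in files:
        piece = read_file(f, nrows=n_rows)
        pieces.append(piece)
        if read_partial:
            n_rows -= len(piece)  # rows still owed after this file
            if n_rows <= 0:  # limit satisfied, skip the remaining files
                break
    return pieces


# Three "files" with 2, 2 and 1 rows; a limit of 3 stops after the second.
assert read_across_files([[1, 2], [3, 4], [5]], n_rows=3) == [[1, 2], [3]]
assert read_across_files([[1, 2], [3, 4], [5]]) == [[1, 2], [3, 4], [5]]
```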
diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py
index 6601d2d29e8..66685d0cefd 100644
--- a/python/cudf_polars/cudf_polars/dsl/translate.py
+++ b/python/cudf_polars/cudf_polars/dsl/translate.py
@@ -76,7 +76,7 @@ def _translate_ir(
 def _(
     node: pl_ir.PythonScan, visitor: NodeTraverser, schema: dict[str, plc.DataType]
 ) -> ir.IR:
-    if visitor.version()[0] == 1:  # pragma: no cover
+    if visitor.version()[0] == 1:
         # https://github.com/pola-rs/polars/pull/17939
         # Versioning can be dropped once polars 1.4 is lowest
         # supported version.
@@ -87,7 +87,7 @@ def _(
             if predicate is not None
             else None
         )
-    else:
+    else:  # pragma: no cover; CI tests 1.4
         # version == 0
         options = node.options
         predicate = (
@@ -108,13 +108,32 @@ def _(
         cloud_options = None
     else:
         reader_options, cloud_options = map(json.loads, options)
+    file_options = node.file_options
+    with_columns = file_options.with_columns
+    n_rows = file_options.n_rows
+    if n_rows is None:
+        n_rows = -1  # All rows
+        skip_rows = 0  # Don't skip
+    else:
+        if visitor.version() >= (1, 0):
+            # Polars 1.4 n_rows property is (skip, nrows)
+            skip_rows, n_rows = n_rows
+        else:  # pragma: no cover; CI tests 1.4
+            # Polars 1.3 n_rows property is an integer; skip_rows was
+            # always zero because it was not pushed down to the reader.
+            skip_rows = 0
+
+    row_index = file_options.row_index
     return ir.Scan(
         schema,
         typ,
         reader_options,
         cloud_options,
         node.paths,
-        node.file_options,
+        with_columns,
+        skip_rows,
+        n_rows,
+        row_index,
         translate_named_expr(visitor, n=node.predicate)
         if node.predicate is not None
         else None,
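The translation layer is where the version-dependent shape of `file_options.n_rows` gets normalised: `None` (no limit), a bare count (older polars), or a `(skip, nrows)` pair (polars 1.4) all funnel into the `(skip_rows, n_rows)` convention that `Scan` now stores, with `-1` meaning "read all rows". A minimal sketch of that normalisation, assuming the version semantics described in the hunk above (`ir_version` plays the role of `visitor.version()`):

```python
def normalise_n_rows(n_rows, ir_version: tuple[int, int]) -> tuple[int, int]:
    """Return (skip_rows, n_rows), where n_rows == -1 means "all rows"."""
    if n_rows is None:
        return 0, -1  # no limit, nothing to skip
    if ir_version >= (1, 0):
        # Newer IR: n_rows is already a (skip, count) pair.
        skip_rows, count = n_rows
        return skip_rows, count
    # Older IR: a bare count; skipping was never pushed down to the reader.
    return 0, n_rows


assert normalise_n_rows(None, (1, 0)) == (0, -1)
assert normalise_n_rows((2, 10), (1, 0)) == (2, 10)
assert normalise_n_rows(10, (0, 0)) == (0, 10)
```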
diff --git a/python/cudf_polars/tests/test_scan.py b/python/cudf_polars/tests/test_scan.py
index e92787238e9..17b2802a479 100644
--- a/python/cudf_polars/tests/test_scan.py
+++ b/python/cudf_polars/tests/test_scan.py
@@ -57,6 +57,22 @@ def mask(request):
     return request.param
 
 
+@pytest.fixture(
+    params=[
+        None,
+        (1, 1),
+    ],
+    ids=[
+        "no-slice",
+        "slice-second",
+    ],
+)
+def slice(request):
+    # For use in testing that we handle
+    # polars slice pushdown correctly
+    return request.param
+
+
 def make_source(df, path, format):
     """
     Writes the passed polars df to a file of
@@ -78,7 +94,9 @@ def make_source(df, path, format):
         ("parquet", pl.scan_parquet),
     ],
 )
-def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, request):
+def test_scan(
+    tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, slice, request
+):
     name, offset = row_index
     make_source(df, tmp_path / "file", format)
     request.applymarker(
@@ -93,6 +111,8 @@ def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, r
         row_index_offset=offset,
         n_rows=n_rows,
     )
+    if slice is not None:
+        q = q.slice(*slice)
     if mask is not None:
         q = q.filter(mask)
     if columns is not None:
@@ -100,6 +120,16 @@ def test_scan(tmp_path, df, format, scan_fn, row_index, n_rows, columns, mask, r
     assert_gpu_result_equal(q)
 
 
+def test_negative_slice_pushdown_raises(tmp_path):
+    df = pl.DataFrame({"a": [1, 2, 3]})
+
+    df.write_parquet(tmp_path / "df.parquet")
+    q = pl.scan_parquet(tmp_path / "df.parquet")
+    # Take the last row
+    q = q.slice(-1, 1)
+    assert_ir_translation_raises(q, NotImplementedError)
+
+
 def test_scan_unsupported_raises(tmp_path):
     df = pl.DataFrame({"a": [1, 2, 3]})
 
@@ -154,15 +184,25 @@ def test_scan_csv_column_renames_projection_schema(tmp_path):
         ("test*.csv", False),
     ],
 )
-def test_scan_csv_multi(tmp_path, filename, glob):
+@pytest.mark.parametrize(
+    "nrows_skiprows",
+    [
+        (None, 0),
+        (1, 1),
+        (3, 0),
+        (4, 2),
+    ],
+)
+def test_scan_csv_multi(tmp_path, filename, glob, nrows_skiprows):
+    n_rows, skiprows = nrows_skiprows
     with (tmp_path / "test1.csv").open("w") as f:
-        f.write("""foo,bar,baz\n1,2\n3,4,5""")
+        f.write("""foo,bar,baz\n1,2,3\n3,4,5""")
     with (tmp_path / "test2.csv").open("w") as f:
-        f.write("""foo,bar,baz\n1,2\n3,4,5""")
+        f.write("""foo,bar,baz\n1,2,3\n3,4,5""")
     with (tmp_path / "test*.csv").open("w") as f:
-        f.write("""foo,bar,baz\n1,2\n3,4,5""")
+        f.write("""foo,bar,baz\n1,2,3\n3,4,5""")
     os.chdir(tmp_path)
-    q = pl.scan_csv(filename, glob=glob)
+    q = pl.scan_csv(filename, glob=glob, n_rows=n_rows, skip_rows=skiprows)
 
     assert_gpu_result_equal(q)
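For completeness, a possible companion test (not part of this diff) exercising the positive-slice path on parquet, where the slice should now be pushed down as `skip_rows`/`num_rows` to the libcudf reader; the test name is hypothetical, while the helpers are the same ones the existing tests import:

```python
import polars as pl

from cudf_polars.testing.asserts import assert_gpu_result_equal


def test_scan_parquet_slice_pushdown(tmp_path):
    df = pl.DataFrame({"a": [1, 2, 3, 4, 5]})
    df.write_parquet(tmp_path / "df.parquet")
    # A positive slice should map onto skip_rows=2 / num_rows=2 in read_parquet
    q = pl.scan_parquet(tmp_path / "df.parquet").slice(2, 2)
    assert_gpu_result_equal(q)
```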