diff --git a/cpp/src/io/utilities/datasource.cpp b/cpp/src/io/utilities/datasource.cpp index 5ccc91e4220..0870e4a84a7 100644 --- a/cpp/src/io/utilities/datasource.cpp +++ b/cpp/src/io/utilities/datasource.cpp @@ -95,8 +95,12 @@ class file_source : public datasource { [[nodiscard]] bool is_device_read_preferred(size_t size) const override { - if (size < _gds_read_preferred_threshold) { return false; } - return supports_device_read(); + if (!supports_device_read()) { return false; } + + // Always prefer device reads if kvikio is enabled + if (!_kvikio_file.closed()) { return true; } + + return size >= _gds_read_preferred_threshold; } std::future device_read_async(size_t offset, diff --git a/docs/cudf/source/cudf_polars/engine_options.md b/docs/cudf/source/cudf_polars/engine_options.md new file mode 100644 index 00000000000..9447047123a --- /dev/null +++ b/docs/cudf/source/cudf_polars/engine_options.md @@ -0,0 +1,24 @@ +# GPUEngine Configuration Options + +The `polars.GPUEngine` object may be configured in several different ways. + +## Parquet Reader Options + +Chunked reading is controlled by passing a dictionary of options to the `GPUEngine` object. Details may be found following the links to the underlying `libcudf` reader. +- `parquet_chunked`, indicicates is chunked parquet reading is to be used, default True. +- [chunk_read_limit](https://docs.rapids.ai/api/libcudf/legacy/classcudf_1_1io_1_1chunked__parquet__reader#aad118178b7536b7966e3325ae1143a1a) controls the maximum size per chunk, default unlimited. +- [pass_read_limit](https://docs.rapids.ai/api/libcudf/legacy/classcudf_1_1io_1_1chunked__parquet__reader#aad118178b7536b7966e3325ae1143a1a) controls the maximum memory used for decompression, default 16GiB. + +For example, one would pass these parameters as follows: +```python +engine = GPUEngine( + raise_on_fail=True, + parquet_options={ + 'parquet_chunked': True, + 'chunk_read_limit': int(1e9), + 'pass_read_limit': int(4e9) + } +) +result = query.collect(engine=engine) +``` +Note that passing `parquet_chunked: False` disables chunked reading entirely, and thus `chunk_read_limit` and `pass_read_limit` will have no effect. diff --git a/docs/cudf/source/cudf_polars/index.rst b/docs/cudf/source/cudf_polars/index.rst index 0a3a0d86b2c..6fd98a6b5da 100644 --- a/docs/cudf/source/cudf_polars/index.rst +++ b/docs/cudf/source/cudf_polars/index.rst @@ -39,3 +39,9 @@ Launch on Google Colab :target: https://colab.research.google.com/github/rapidsai-community/showcase/blob/main/accelerated_data_processing_examples/polars_gpu_engine_demo.ipynb Try out the GPU engine for Polars in a free GPU notebook environment. Sign in with your Google account and `launch the demo on Colab `__. + +.. toctree:: + :maxdepth: 1 + :caption: Engine Config Options: + + engine_options diff --git a/python/cudf_polars/cudf_polars/callback.py b/python/cudf_polars/cudf_polars/callback.py index ff4933c7564..b41fa3e13b4 100644 --- a/python/cudf_polars/cudf_polars/callback.py +++ b/python/cudf_polars/cudf_polars/callback.py @@ -129,6 +129,7 @@ def set_device(device: int | None) -> Generator[int, None, None]: def _callback( ir: IR, + config: GPUEngine, with_columns: list[str] | None, pyarrow_predicate: str | None, n_rows: int | None, @@ -145,7 +146,7 @@ def _callback( set_device(device), set_memory_resource(memory_resource), ): - return ir.evaluate(cache={}).to_polars() + return ir.evaluate(cache={}, config=config).to_polars() def execute_with_cudf( @@ -174,7 +175,7 @@ def execute_with_cudf( device = config.device memory_resource = config.memory_resource raise_on_fail = config.config.get("raise_on_fail", False) - if unsupported := (config.config.keys() - {"raise_on_fail"}): + if unsupported := (config.config.keys() - {"raise_on_fail", "parquet_options"}): raise ValueError( f"Engine configuration contains unsupported settings {unsupported}" ) @@ -201,7 +202,11 @@ def execute_with_cudf( else: nt.set_udf( partial( - _callback, ir, device=device, memory_resource=memory_resource + _callback, + ir, + config, + device=device, + memory_resource=memory_resource, ) ) except exception as e: diff --git a/python/cudf_polars/cudf_polars/dsl/ir.py b/python/cudf_polars/cudf_polars/dsl/ir.py index 1f935190f28..7bfca814d0b 100644 --- a/python/cudf_polars/cudf_polars/dsl/ir.py +++ b/python/cudf_polars/cudf_polars/dsl/ir.py @@ -37,6 +37,8 @@ from collections.abc import Callable, Hashable, MutableMapping, Sequence from typing import Literal + from polars import GPUEngine + from cudf_polars.typing import Schema @@ -180,7 +182,9 @@ def get_hashable(self) -> Hashable: translation phase should fail earlier. """ - def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: + def evaluate( + self, *, cache: MutableMapping[int, DataFrame], config: GPUEngine + ) -> DataFrame: """ Evaluate the node (recursively) and return a dataframe. @@ -189,6 +193,8 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: cache Mapping from cached node ids to constructed DataFrames. Used to implement evaluation of the `Cache` node. + config + GPU engine configuration. Notes ----- @@ -208,8 +214,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: translation phase should fail earlier. """ return self.do_evaluate( + config, *self._non_child_args, - *(child.evaluate(cache=cache) for child in self.children), + *(child.evaluate(cache=cache, config=config) for child in self.children), ) @@ -293,6 +300,9 @@ class Scan(IR): predicate: expr.NamedExpr | None """Mask to apply to the read dataframe.""" + PARQUET_DEFAULT_CHUNK_SIZE: int = 0 + PARQUET_DEFAULT_PASS_LIMIT: int = 17179869184 # 16GiB + def __init__( self, schema: Schema, @@ -412,6 +422,7 @@ def get_hashable(self) -> Hashable: @classmethod def do_evaluate( cls, + config: GPUEngine, schema: Schema, typ: str, reader_options: dict[str, Any], @@ -497,25 +508,59 @@ def do_evaluate( colnames[0], ) elif typ == "parquet": - filters = None - if predicate is not None and row_index is None: - # Can't apply filters during read if we have a row index. - filters = to_parquet_filter(predicate.value) - tbl_w_meta = plc.io.parquet.read_parquet( - plc.io.SourceInfo(paths), - columns=with_columns, - filters=filters, - nrows=n_rows, - skip_rows=skip_rows, - ) - df = DataFrame.from_table( - tbl_w_meta.tbl, - # TODO: consider nested column names? - tbl_w_meta.column_names(include_children=False), - ) - if filters is not None: - # Mask must have been applied. - return df + parquet_options = config.config.get("parquet_options", {}) + if parquet_options.get("chunked", False): + reader = plc.io.parquet.ChunkedParquetReader( + plc.io.SourceInfo(paths), + columns=with_columns, + nrows=n_rows, + skip_rows=skip_rows, + chunk_read_limit=parquet_options.get( + "chunk_read_limit", cls.PARQUET_DEFAULT_CHUNK_SIZE + ), + pass_read_limit=parquet_options.get( + "pass_read_limit", cls.PARQUET_DEFAULT_PASS_LIMIT + ), + ) + chk = reader.read_chunk() + tbl = chk.tbl + names = chk.column_names() + concatenated_columns = tbl.columns() + while reader.has_next(): + tbl = reader.read_chunk().tbl + + for i in range(tbl.num_columns()): + concatenated_columns[i] = plc.concatenate.concatenate( + [concatenated_columns[i], tbl._columns[i]] + ) + # Drop residual columns to save memory + tbl._columns[i] = None + + df = DataFrame.from_table( + plc.Table(concatenated_columns), + names=names, + ) + else: + filters = None + if predicate is not None and row_index is None: + # Can't apply filters during read if we have a row index. + filters = to_parquet_filter(predicate.value) + tbl_w_meta = plc.io.parquet.read_parquet( + plc.io.SourceInfo(paths), + columns=with_columns, + filters=filters, + nrows=n_rows, + skip_rows=skip_rows, + ) + df = DataFrame.from_table( + tbl_w_meta.tbl, + # TODO: consider nested column names? + tbl_w_meta.column_names(include_children=False), + ) + if filters is not None: + # Mask must have been applied. + return df + elif typ == "ndjson": json_schema: list[plc.io.json.NameAndType] = [ (name, typ, []) for name, typ in schema.items() @@ -590,14 +635,16 @@ def __init__(self, schema: Schema, key: int, value: IR): @classmethod def do_evaluate( - cls, key: int, df: DataFrame + cls, config: GPUEngine, key: int, df: DataFrame ) -> DataFrame: # pragma: no cover; basic evaluation never calls this """Evaluate and return a dataframe.""" # Our value has already been computed for us, so let's just # return it. return df - def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: + def evaluate( + self, *, cache: MutableMapping[int, DataFrame], config: GPUEngine + ) -> DataFrame: """Evaluate and return a dataframe.""" # We must override the recursion scheme because we don't want # to recurse if we're in the cache. @@ -605,7 +652,9 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame: return cache[self.key] except KeyError: (value,) = self.children - return cache.setdefault(self.key, value.evaluate(cache=cache)) + return cache.setdefault( + self.key, value.evaluate(cache=cache, config=config) + ) class DataFrameScan(IR): @@ -651,6 +700,7 @@ def get_hashable(self) -> Hashable: @classmethod def do_evaluate( cls, + config: GPUEngine, schema: Schema, df: Any, projection: tuple[str, ...] | None, @@ -698,6 +748,7 @@ def __init__( @classmethod def do_evaluate( cls, + config: GPUEngine, exprs: tuple[expr.NamedExpr, ...], should_broadcast: bool, # noqa: FBT001 df: DataFrame, @@ -732,7 +783,10 @@ def __init__( @classmethod def do_evaluate( - cls, exprs: tuple[expr.NamedExpr, ...], df: DataFrame + cls, + config: GPUEngine, + exprs: tuple[expr.NamedExpr, ...], + df: DataFrame, ) -> DataFrame: # pragma: no cover; not exposed by polars yet """Evaluate and return a dataframe.""" columns = broadcast(*(e.evaluate(df) for e in exprs)) @@ -823,6 +877,7 @@ def check_agg(agg: expr.Expr) -> int: @classmethod def do_evaluate( cls, + config: GPUEngine, keys_in: Sequence[expr.NamedExpr], agg_requests: Sequence[expr.NamedExpr], maintain_order: bool, # noqa: FBT001 @@ -944,6 +999,7 @@ def __init__( @classmethod def do_evaluate( cls, + config: GPUEngine, predicate: plc.expressions.Expression, zlice: tuple[int, int] | None, suffix: str, @@ -1116,6 +1172,7 @@ def _reorder_maps( @classmethod def do_evaluate( cls, + config: GPUEngine, left_on_exprs: Sequence[expr.NamedExpr], right_on_exprs: Sequence[expr.NamedExpr], options: tuple[ @@ -1239,6 +1296,7 @@ def __init__( @classmethod def do_evaluate( cls, + config: GPUEngine, exprs: Sequence[expr.NamedExpr], should_broadcast: bool, # noqa: FBT001 df: DataFrame, @@ -1301,6 +1359,7 @@ def __init__( @classmethod def do_evaluate( cls, + config: GPUEngine, keep: plc.stream_compaction.DuplicateKeepOption, subset: frozenset[str] | None, zlice: tuple[int, int] | None, @@ -1390,6 +1449,7 @@ def __init__( @classmethod def do_evaluate( cls, + config: GPUEngine, by: Sequence[expr.NamedExpr], order: Sequence[plc.types.Order], null_order: Sequence[plc.types.NullOrder], @@ -1445,7 +1505,9 @@ def __init__(self, schema: Schema, offset: int, length: int, df: IR): self.children = (df,) @classmethod - def do_evaluate(cls, offset: int, length: int, df: DataFrame) -> DataFrame: + def do_evaluate( + cls, config: GPUEngine, offset: int, length: int, df: DataFrame + ) -> DataFrame: """Evaluate and return a dataframe.""" return df.slice((offset, length)) @@ -1465,7 +1527,9 @@ def __init__(self, schema: Schema, mask: expr.NamedExpr, df: IR): self.children = (df,) @classmethod - def do_evaluate(cls, mask_expr: expr.NamedExpr, df: DataFrame) -> DataFrame: + def do_evaluate( + cls, config: GPUEngine, mask_expr: expr.NamedExpr, df: DataFrame + ) -> DataFrame: """Evaluate and return a dataframe.""" (mask,) = broadcast(mask_expr.evaluate(df), target_length=df.num_rows) return df.filter(mask) @@ -1483,7 +1547,7 @@ def __init__(self, schema: Schema, df: IR): self.children = (df,) @classmethod - def do_evaluate(cls, schema: Schema, df: DataFrame) -> DataFrame: + def do_evaluate(cls, config: GPUEngine, schema: Schema, df: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" # This can reorder things. columns = broadcast( @@ -1559,7 +1623,9 @@ def __init__(self, schema: Schema, name: str, options: Any, df: IR): self._non_child_args = (name, self.options) @classmethod - def do_evaluate(cls, name: str, options: Any, df: DataFrame) -> DataFrame: + def do_evaluate( + cls, config: GPUEngine, name: str, options: Any, df: DataFrame + ) -> DataFrame: """Evaluate and return a dataframe.""" if name == "rechunk": # No-op in our data model @@ -1638,7 +1704,9 @@ def __init__(self, schema: Schema, zlice: tuple[int, int] | None, *children: IR) raise NotImplementedError("Schema mismatch") @classmethod - def do_evaluate(cls, zlice: tuple[int, int] | None, *dfs: DataFrame) -> DataFrame: + def do_evaluate( + cls, config: GPUEngine, zlice: tuple[int, int] | None, *dfs: DataFrame + ) -> DataFrame: """Evaluate and return a dataframe.""" # TODO: only evaluate what we need if we have a slice? return DataFrame.from_table( @@ -1687,7 +1755,7 @@ def _extend_with_nulls(table: plc.Table, *, nrows: int) -> plc.Table: ) @classmethod - def do_evaluate(cls, *dfs: DataFrame) -> DataFrame: + def do_evaluate(cls, config: GPUEngine, *dfs: DataFrame) -> DataFrame: """Evaluate and return a dataframe.""" max_rows = max(df.num_rows for df in dfs) # Horizontal concatenation extends shorter tables with nulls diff --git a/python/cudf_polars/tests/dsl/test_to_ast.py b/python/cudf_polars/tests/dsl/test_to_ast.py index f6c24da0180..795ba991c62 100644 --- a/python/cudf_polars/tests/dsl/test_to_ast.py +++ b/python/cudf_polars/tests/dsl/test_to_ast.py @@ -63,7 +63,7 @@ def test_compute_column(expr, df): ir = Translator(q._ldf.visit()).translate_ir() assert isinstance(ir, ir_nodes.Select) - table = ir.children[0].evaluate(cache={}) + table = ir.children[0].evaluate(cache={}, config=pl.GPUEngine()) name_to_index = {c.name: i for i, c in enumerate(table.columns)} def compute_column(e): diff --git a/python/cudf_polars/tests/dsl/test_traversal.py b/python/cudf_polars/tests/dsl/test_traversal.py index 8958c2a0f84..8849629e0fd 100644 --- a/python/cudf_polars/tests/dsl/test_traversal.py +++ b/python/cudf_polars/tests/dsl/test_traversal.py @@ -124,7 +124,7 @@ def replace_df(node, rec): new = mapper(orig) - result = new.evaluate(cache={}).to_polars() + result = new.evaluate(cache={}, config=pl.GPUEngine()).to_polars() expect = pl.DataFrame({"a": [2, 1], "b": [-4, -3]}) @@ -153,7 +153,7 @@ def replace_scan(node, rec): orig = Translator(q._ldf.visit()).translate_ir() new = mapper(orig) - result = new.evaluate(cache={}).to_polars() + result = new.evaluate(cache={}, config=pl.GPUEngine()).to_polars() expect = q.collect() @@ -224,6 +224,6 @@ def _(node: ir.Select, fn: IRTransformer): new_ir = rewriter(qir) - got = new_ir.evaluate(cache={}).to_polars() + got = new_ir.evaluate(cache={}, config=pl.GPUEngine()).to_polars() assert_frame_equal(expect, got) diff --git a/python/cudf_polars/tests/expressions/test_sort.py b/python/cudf_polars/tests/expressions/test_sort.py index 6170281ad54..49e075e0338 100644 --- a/python/cudf_polars/tests/expressions/test_sort.py +++ b/python/cudf_polars/tests/expressions/test_sort.py @@ -68,7 +68,11 @@ def test_setsorted(descending, nulls_last, with_nulls): assert_gpu_result_equal(q) - df = Translator(q._ldf.visit()).translate_ir().evaluate(cache={}) + df = ( + Translator(q._ldf.visit()) + .translate_ir() + .evaluate(cache={}, config=pl.GPUEngine()) + ) a = df.column_map["a"]