diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py
index 3e1a4b1f024..6c91a5a8765 100644
--- a/python/cudf/cudf/io/parquet.py
+++ b/python/cudf/cudf/io/parquet.py
@@ -1,14 +1,18 @@
 # Copyright (c) 2019-2023, NVIDIA CORPORATION.
+from __future__ import annotations
 
 import math
+import operator
 import shutil
 import tempfile
 import warnings
 from collections import defaultdict
 from contextlib import ExitStack
-from typing import Dict, List, Optional, Tuple
+from functools import partial, reduce
+from typing import Callable, Dict, List, Optional, Tuple
 from uuid import uuid4
 
+import numpy as np
 import pandas as pd
 from pyarrow import dataset as ds, parquet as pq
 
@@ -481,6 +485,9 @@ def read_parquet(
         path_or_data=filepath_or_buffer, storage_options=storage_options
     )
 
+    # Normalize and validate filters
+    filters = _normalize_filters(filters)
+
     # Use pyarrow dataset to detect/process directory-partitioned
     # data and apply filters. Note that we can only support partitioned
     # data and filtering if the input is a single directory or list of
@@ -501,8 +508,6 @@
             categorical_partitions=categorical_partitions,
             dataset_kwargs=dataset_kwargs,
         )
-    elif filters is not None:
-        raise ValueError("cudf cannot apply filters to open file objects.")
     filepath_or_buffer = paths if paths else filepath_or_buffer
 
     filepaths_or_buffers = []
@@ -547,7 +552,8 @@
             "for full CPU-based filtering functionality."
         )
 
-    return _parquet_to_frame(
+    # Convert parquet data to a cudf.DataFrame
+    df = _parquet_to_frame(
         filepaths_or_buffers,
         engine,
         *args,
@@ -561,6 +567,115 @@
         **kwargs,
     )
 
+    # Apply filters row-wise (if any are defined), and return
+    return _apply_post_filters(df, filters)
+
+
+def _normalize_filters(filters: list | None) -> List[List[tuple]] | None:
+    # Utility to normalize and validate the `filters`
+    # argument to `read_parquet`
+    if not filters:
+        return None
+
+    msg = (
+        "filters must be None, or non-empty List[Tuple] "
+        f"or List[List[Tuple]]. Got {filters}"
+    )
+    if not isinstance(filters, list):
+        raise TypeError(msg)
+
+    def _validate_predicate(item):
+        if not isinstance(item, tuple) or len(item) != 3:
+            raise TypeError(
+                f"Predicate must be Tuple[str, str, Any], "
+                f"got {item}."
+            )
+
+    filters = filters if isinstance(filters[0], list) else [filters]
+    for conjunction in filters:
+        if not conjunction or not isinstance(conjunction, list):
+            raise TypeError(msg)
+        for predicate in conjunction:
+            _validate_predicate(predicate)
+
+    return filters
+
+
+def _apply_post_filters(
+    df: cudf.DataFrame, filters: List[List[tuple]] | None
+) -> cudf.DataFrame:
+    """Apply DNF filters to an in-memory DataFrame
+
+    Disjunctive normal form (DNF) means that the inner-most
+    tuple describes a single column predicate. These inner
+    predicates are combined with an AND conjunction into a
+    larger predicate. The outer-most list then combines these
+    conjunctions with an OR disjunction.
+    """
+
+    if not filters:
+        # No filters to apply
+        return df
+
+    def _handle_in(column: cudf.Series, value, *, negate) -> cudf.Series:
+        if not isinstance(value, (list, set, tuple)):
+            raise TypeError(
+                "Value of 'in'/'not in' filter must be a list, set, or tuple."
+            )
+        return ~column.isin(value) if negate else column.isin(value)
+
+    def _handle_is(column: cudf.Series, value, *, negate) -> cudf.Series:
+        if value not in {np.nan, None}:
+            raise TypeError(
+                "Value of 'is'/'is not' filter must be np.nan or None."
+            )
+        return ~column.isna() if negate else column.isna()
+
+    handlers: Dict[str, Callable] = {
+        "==": operator.eq,
+        "!=": operator.ne,
+        "<": operator.lt,
+        "<=": operator.le,
+        ">": operator.gt,
+        ">=": operator.ge,
+        "in": partial(_handle_in, negate=False),
+        "not in": partial(_handle_in, negate=True),
+        "is": partial(_handle_is, negate=False),
+        "is not": partial(_handle_is, negate=True),
+    }
+
+    # We can reset the index before returning if we filter
+    # out rows from a DataFrame with a default RangeIndex
+    # (to reduce memory usage)
+    reset_index = (
+        isinstance(df.index, cudf.RangeIndex)
+        and df.index.name is None
+        and df.index.start == 0
+        and df.index.step == 1
+    )
+
+    try:
+        selection: cudf.Series = reduce(
+            operator.or_,
+            (
+                reduce(
+                    operator.and_,
+                    (
+                        handlers[op](df[column], value)
+                        for (column, op, value) in expr
+                    ),
+                )
+                for expr in filters
+            ),
+        )
+        if reset_index:
+            return df[selection].reset_index(drop=True)
+        return df[selection]
+    except (KeyError, TypeError):
+        warnings.warn(
+            f"Row-wise filtering failed in read_parquet for {filters}"
+        )
+        return df
 
 
 @_cudf_nvtx_annotate
 def _parquet_to_frame(
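Note on the DNF reduction in `_apply_post_filters`: each inner list of predicates is AND-ed into a single boolean mask, and the resulting masks are OR-ed together. A minimal pandas-only sketch of the same `reduce` pattern (the example data is assumed, and only the comparison handlers from the table above are wired up, not the `in`/`is` variants):

```python
import operator
from functools import reduce

import pandas as pd

df = pd.DataFrame({"x": range(10), "y": list("aabbccddee")})

# DNF: the outer list is an OR of conjunctions; each inner list
# is an AND of (column, op, value) predicates.
filters = [[("x", "<", 3)], [("y", "==", "c")]]

handlers = {
    "==": operator.eq,
    "!=": operator.ne,
    "<": operator.lt,
    "<=": operator.le,
    ">": operator.gt,
    ">=": operator.ge,
}

selection = reduce(
    operator.or_,
    (
        reduce(
            operator.and_,
            (handlers[op](df[col], val) for col, op, val in conjunction),
        )
        for conjunction in filters
    ),
)
print(df[selection])  # rows where x < 3 OR y == "c"
```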
diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py
index 0ab5d35f9f8..7aeb2afc2c7 100644
--- a/python/cudf/cudf/tests/test_parquet.py
+++ b/python/cudf/cudf/tests/test_parquet.py
@@ -528,9 +528,7 @@ def test_parquet_read_filtered_multiple_files(tmpdir):
     )
     assert_eq(
         filtered_df,
-        cudf.DataFrame(
-            {"x": [2, 3, 2, 3], "y": list("bbcc")}, index=[2, 3, 2, 3]
-        ),
+        cudf.DataFrame({"x": [2, 2], "y": list("bc")}, index=[2, 2]),
     )
 
 
@@ -541,13 +539,16 @@ def test_parquet_read_filtered_multiple_files(tmpdir):
 @pytest.mark.parametrize(
     "predicate,expected_len",
     [
-        ([[("x", "==", 0)], [("z", "==", 0)]], 4),
-        ([("x", "==", 0), ("z", "==", 0)], 0),
-        ([("x", "==", 0), ("z", "!=", 0)], 2),
+        ([[("x", "==", 0)], [("z", "==", 0)]], 2),
         ([("x", "==", 0), ("z", "==", 0)], 0),
+        ([("x", "==", 0), ("z", "!=", 0)], 1),
         ([("y", "==", "c"), ("x", ">", 8)], 0),
-        ([("y", "==", "c"), ("x", ">=", 5)], 2),
-        ([[("y", "==", "c")], [("x", "<", 3)]], 6),
+        ([("y", "==", "c"), ("x", ">=", 5)], 1),
+        ([[("y", "==", "c")], [("x", "<", 3)]], 5),
+        ([[("x", "not in", (0, 9)), ("z", "not in", (4, 5))]], 6),
+        ([[("y", "==", "c")], [("x", "in", (0, 9)), ("z", "in", (0, 9))]], 4),
+        ([[("x", "==", 0)], [("x", "==", 1)], [("x", "==", 2)]], 3),
+        ([[("x", "==", 0), ("z", "==", 9), ("y", "==", "a")]], 1),
     ],
 )
 def test_parquet_read_filtered_complex_predicate(
@@ -556,7 +557,11 @@ def test_parquet_read_filtered_complex_predicate(
     # Generate data
     fname = tmpdir.join("filtered_complex_predicate.parquet")
     df = pd.DataFrame(
-        {"x": range(10), "y": list("aabbccddee"), "z": reversed(range(10))}
+        {
+            "x": range(10),
+            "y": list("aabbccddee"),
+            "z": reversed(range(10)),
+        }
     )
     df.to_parquet(fname, row_group_size=2)
 
@@ -1954,26 +1959,17 @@ def test_read_parquet_partitioned_filtered(
     assert got.dtypes["c"] == "int"
     assert_eq(expect, got)
 
-    # Filter on non-partitioned column.
-    # Cannot compare to pandas, since the pyarrow
-    # backend will filter by row (and cudf can
-    # only filter by column, for now)
+    # Filter on non-partitioned column
     filters = [("a", "==", 10)]
-    got = cudf.read_parquet(
-        read_path,
-        filters=filters,
-        row_groups=row_groups,
-    )
-    assert len(got) < len(df) and 10 in got["a"]
+    got = cudf.read_parquet(read_path, filters=filters)
+    expect = pd.read_parquet(read_path, filters=filters)
+    assert_eq(expect, got)
 
     # Filter on both kinds of columns
     filters = [[("a", "==", 10)], [("c", "==", 1)]]
-    got = cudf.read_parquet(
-        read_path,
-        filters=filters,
-        row_groups=row_groups,
-    )
-    assert len(got) < len(df) and (1 in got["c"] and 10 in got["a"])
+    got = cudf.read_parquet(read_path, filters=filters)
+    expect = pd.read_parquet(read_path, filters=filters)
+    assert_eq(expect, got)
 
 
 def test_parquet_writer_chunked_metadata(tmpdir, simple_pdf, simple_gdf):
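The updated `expected_len` values can be re-derived by hand: with `x = 0..9`, `y = "aabbccddee"`, and `z = 9..0`, each expectation is just the row count of the corresponding boolean mask. A standalone pandas cross-check of two of the new cases (same data as the test above):

```python
import pandas as pd

df = pd.DataFrame(
    {"x": range(10), "y": list("aabbccddee"), "z": list(reversed(range(10)))}
)

# [[("x", "not in", (0, 9)), ("z", "not in", (4, 5))]] -> 6 rows
assert len(df[~df["x"].isin((0, 9)) & ~df["z"].isin((4, 5))]) == 6

# [[("y", "==", "c")], [("x", "in", (0, 9)), ("z", "in", (0, 9))]] -> 4 rows
mask = (df["y"] == "c") | (df["x"].isin((0, 9)) & df["z"].isin((0, 9)))
assert len(df[mask]) == 4
```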
diff --git a/python/cudf/cudf/utils/ioutils.py b/python/cudf/cudf/utils/ioutils.py
index bf51b360fec..bf9dc226d11 100644
--- a/python/cudf/cudf/utils/ioutils.py
+++ b/python/cudf/cudf/utils/ioutils.py
@@ -147,11 +147,12 @@
 For other URLs (e.g. starting with "s3://", and "gcs://") the key-value
 pairs are forwarded to ``fsspec.open``. Please see ``fsspec`` and
 ``urllib`` for more details.
-filters : list of tuple, list of lists of tuples default None
+filters : list of tuple, list of lists of tuples, default None
     If not None, specifies a filter predicate used to filter out row groups
     using statistics stored for each row group as Parquet metadata. Row groups
-    that do not match the given filter predicate are not read. The
-    predicate is expressed in disjunctive normal form (DNF) like
+    that do not match the given filter predicate are not read. The filters
+    will also be applied to the rows of the in-memory DataFrame after IO.
+    The predicate is expressed in disjunctive normal form (DNF) like
     `[[('x', '=', 0), ...], ...]`. DNF allows arbitrary boolean logical
     combinations of single column predicates. The innermost tuples each
     describe a single column predicate. The list of inner predicates is
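To make the documented semantics concrete, a small usage sketch (the file name and data here are hypothetical; with this change the predicate filters individual rows after IO, not just whole row groups):

```python
import cudf
import pandas as pd

pd.DataFrame({"x": range(10), "y": list("aabbccddee")}).to_parquet(
    "example.parquet"
)

# Keep rows where (x == 0 AND y == "a") OR x == 9
out = cudf.read_parquet(
    "example.parquet",
    filters=[[("x", "==", 0), ("y", "==", "a")], [("x", "==", 9)]],
)
assert sorted(out["x"].to_pandas()) == [0, 9]
```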
diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py
index 67fc7215fc7..bf0f115f310 100644
--- a/python/dask_cudf/dask_cudf/io/parquet.py
+++ b/python/dask_cudf/dask_cudf/io/parquet.py
@@ -20,7 +20,11 @@
 import cudf
 from cudf.core.column import as_column, build_categorical_column
 from cudf.io import write_to_dataset
-from cudf.io.parquet import _default_open_file_options
+from cudf.io.parquet import (
+    _apply_post_filters,
+    _default_open_file_options,
+    _normalize_filters,
+)
 from cudf.utils.dtypes import cudf_dtype_from_pa_type
 from cudf.utils.ioutils import (
     _ROW_GROUP_SIZE_BYTES_DEFAULT,
@@ -69,6 +73,7 @@ def _read_paths(
     fs,
     columns=None,
     row_groups=None,
+    filters=None,
     strings_to_categorical=None,
     partitions=None,
     partitioning=None,
@@ -134,6 +139,10 @@ def _read_paths(
            else:
                raise err
 
+    # Apply filters (if any are defined)
+    filters = _normalize_filters(filters)
+    df = _apply_post_filters(df, filters)
+
    if partitions and partition_keys is None:
 
        # Use `HivePartitioning` by default
@@ -183,6 +192,7 @@ def read_partition(
         index,
         categories=(),
         partitions=(),
+        filters=None,
         partitioning=None,
         schema=None,
         open_file_options=None,
@@ -255,6 +265,7 @@ def read_partition(
                             fs,
                             columns=read_columns,
                             row_groups=rgs if rgs else None,
+                            filters=filters,
                             strings_to_categorical=strings_to_cats,
                             partitions=partitions,
                             partitioning=partitioning,
@@ -281,6 +292,7 @@ def read_partition(
                         fs,
                         columns=read_columns,
                         row_groups=rgs if rgs else None,
+                        filters=filters,
                         strings_to_categorical=strings_to_cats,
                         partitions=partitions,
                         partitioning=partitioning,
diff --git a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py
index f5ae9706fde..8e80aad67d1 100644
--- a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py
+++ b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py
@@ -254,6 +254,41 @@ def test_filters(tmpdir):
     assert not len(c)
 
 
+@pytest.mark.parametrize("numeric", [True, False])
+@pytest.mark.parametrize("null", [np.nan, None])
+def test_isna_filters(tmpdir, null, numeric):
+
+    tmp_path = str(tmpdir)
+    df = pd.DataFrame(
+        {
+            "x": range(10),
+            "y": list("aabbccddee"),
+            "i": [0] * 4 + [np.nan] * 2 + [0] * 4,
+            "j": [""] * 4 + [None] * 2 + [""] * 4,
+        }
+    )
+    ddf = dd.from_pandas(df, npartitions=5)
+    assert ddf.npartitions == 5
+    ddf.to_parquet(tmp_path, engine="pyarrow")
+
+    # Test "is"
+    col = "i" if numeric else "j"
+    filters = [(col, "is", null)]
+    out = dask_cudf.read_parquet(
+        tmp_path, filters=filters, split_row_groups=True
+    )
+    assert len(out) == 2
+    assert list(out.x.compute().values) == [4, 5]
+
+    # Test "is not"
+    filters = [(col, "is not", null)]
+    out = dask_cudf.read_parquet(
+        tmp_path, filters=filters, split_row_groups=True
+    )
+    assert len(out) == 8
+    assert list(out.x.compute().values) == [0, 1, 2, 3, 6, 7, 8, 9]
+
+
 def test_filters_at_row_group_level(tmpdir):
 
     tmp_path = str(tmpdir)
@@ -267,7 +302,7 @@ def test_filters_at_row_group_level(tmpdir):
         tmp_path, filters=[("x", "==", 1)], split_row_groups=True
     )
     assert a.npartitions == 1
-    assert (a.shape[0] == 2).compute()
+    assert (a.shape[0] == 1).compute()
 
     ddf.to_parquet(tmp_path, engine="pyarrow", row_group_size=1)
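For reference, the new null-aware operators can be exercised end-to-end much like `test_isna_filters` above. A condensed sketch (file name and data hypothetical, mirroring the test's use of `split_row_groups=True`):

```python
import numpy as np
import pandas as pd
import dask_cudf

pd.DataFrame(
    {"x": range(6), "i": [0.0, 0.0, np.nan, np.nan, 0.0, 0.0]}
).to_parquet("example.parquet", row_group_size=2)

# "is" keeps only the null rows; "is not" keeps the complement
nulls = dask_cudf.read_parquet(
    "example.parquet", filters=[("i", "is", np.nan)], split_row_groups=True
)
assert sorted(nulls["x"].compute().to_pandas()) == [2, 3]

valid = dask_cudf.read_parquet(
    "example.parquet", filters=[("i", "is not", np.nan)], split_row_groups=True
)
assert sorted(valid["x"].compute().to_pandas()) == [0, 1, 4, 5]
```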