# Add row-wise filtering step to read_parquet (#13334)
This PR adds a `post_filters` argument to `cudf.read_parquet`, which is set equal to the `filters` argument by default. When this argument is set, the specified DNF (disjunctive normal form) filter expression will be applied to the in-memory `cudf.DataFrame` object after IO is performed.

The overall result of this PR is that the behavior of `cudf.read_parquet` becomes more consistent with that of `pd.read_parquet`: both libraries now enforce filters at row-wise granularity by default.
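
For illustration, here is a minimal usage sketch (the file name and data are made up for this example, not taken from the test suite). With row-group pruning alone, the filter below keeps two whole row groups; with the new row-wise step, only the matching rows are returned:

```python
import cudf
import pandas as pd

# Hypothetical example data, written with small row groups so that
# row-group pruning and row-wise filtering give different results.
pd.DataFrame(
    {"x": range(10), "y": list("aabbccddee")}
).to_parquet("data.parquet", row_group_size=2)

# DNF filter: (x >= 5 AND y == "c") OR (x < 1)
out = cudf.read_parquet(
    "data.parquet",
    filters=[[("x", ">=", 5), ("y", "==", "c")], [("x", "<", 1)]],
)

# Statistics-based row-group pruning alone would keep two row groups
# (4 rows); the post-IO filtering step narrows this to the 2 rows
# that actually satisfy the predicate.
assert sorted(out["x"].to_pandas().tolist()) == [0, 5]
```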

### Note on the "need" for distinct `filters` and `post_filters` arguments

My hope is that `post_filters` will eventually go away. However, I added a distinct argument for two general reasons:

1. PyArrow does not yet support `"is"` and `"is not"` operators in `filters`. Therefore, we cannot pass along **all** filters from `dask`/`dask_cudf` down to `cudf.read_parquet` using the existing `filters` argument, because we rely on pyarrow to filter out row-groups (note that dask implements its own filter-conversion utility to avoid this problem). I'm hoping pyarrow will eventually adopt these comparison types (xref: apache/arrow#34504). A short sketch of this case follows this list.
2. When `cudf.read_parquet` is called from `dask_cudf.read_parquet`, row-group filtering will have already been applied. Therefore, it is convenient to be able to ask cudf for only the post-IO, row-wise filtering step. Otherwise, we end up duplicating some metadata processing.
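
As an illustration of point 1, the sketch below is condensed from the new `test_isna_filters` test added in this PR (the output path is a placeholder). Row-group pruning uses dask's own filter conversion, while the `"is not"` check is enforced row-wise by the new post-IO step in cudf:

```python
import numpy as np
import pandas as pd
import dask.dataframe as dd
import dask_cudf

# A column containing a couple of nulls, written across several files.
pdf = pd.DataFrame({"x": range(10), "i": [0] * 4 + [np.nan] * 2 + [0] * 4})
dd.from_pandas(pdf, npartitions=5).to_parquet("tmp_path", engine="pyarrow")

# "is not" cannot be expressed through pyarrow's filter expressions,
# but it can be enforced by the row-wise filtering step added here.
out = dask_cudf.read_parquet(
    "tmp_path", filters=[("i", "is not", np.nan)], split_row_groups=True
)
assert len(out) == 8  # the two null rows are dropped
```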

My primary concern with adding `post_filters` is that row-wise filtering *could* be added at the cuio/libcudf level in the future. In that (hypothetical) case, `post_filters` wouldn't really be providing any value, but we would probably be able to deprecate it without much pain (if any).

## Additional Context

This feature is ultimately **needed** to support general predicate-pushdown optimizations in Dask Expressions (dask-expr), because the high-level optimization logic becomes much simpler when a filter operation on a `ReadParquet` expression can be iteratively "absorbed" into the root `ReadParquet` expression.
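
To make that concrete, here is a heavily simplified, purely hypothetical sketch of the rewrite; none of these class or method names are taken from dask-expr, and the real optimizer is far more involved:

```python
# Hypothetical, simplified expression classes -- for illustration only.
class ReadParquet:
    def __init__(self, path, filters=None):
        self.path = path
        self.filters = filters or []


class Filter:
    def __init__(self, frame, predicate):
        self.frame = frame
        self.predicate = predicate

    def simplify(self):
        # Because cudf.read_parquet now enforces `filters` row-wise,
        # a Filter wrapping a ReadParquet can be absorbed into the
        # read itself without changing the result.
        if isinstance(self.frame, ReadParquet):
            return ReadParquet(
                self.frame.path,
                filters=self.frame.filters + [self.predicate],
            )
        return self


expr = Filter(ReadParquet("data.parquet"), ("x", ">", 5))
assert isinstance(expr.simplify(), ReadParquet)
```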

Authors:
  - Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
  - Bradley Dice (https://github.com/bdice)
  - Lawrence Mitchell (https://github.com/wence-)

URL: #13334
rjzamora authored May 17, 2023
1 parent da09157 commit 1089997
Showing 5 changed files with 192 additions and 34 deletions.
123 changes: 119 additions & 4 deletions python/cudf/cudf/io/parquet.py
@@ -1,14 +1,18 @@
# Copyright (c) 2019-2023, NVIDIA CORPORATION.
from __future__ import annotations

import math
import operator
import shutil
import tempfile
import warnings
from collections import defaultdict
from contextlib import ExitStack
from typing import Dict, List, Optional, Tuple
from functools import partial, reduce
from typing import Callable, Dict, List, Optional, Tuple
from uuid import uuid4

import numpy as np
import pandas as pd
from pyarrow import dataset as ds, parquet as pq

@@ -481,6 +485,9 @@ def read_parquet(
path_or_data=filepath_or_buffer, storage_options=storage_options
)

# Normalize and validate filters
filters = _normalize_filters(filters)

# Use pyarrow dataset to detect/process directory-partitioned
# data and apply filters. Note that we can only support partitioned
# data and filtering if the input is a single directory or list of
@@ -501,8 +508,6 @@
categorical_partitions=categorical_partitions,
dataset_kwargs=dataset_kwargs,
)
elif filters is not None:
raise ValueError("cudf cannot apply filters to open file objects.")
filepath_or_buffer = paths if paths else filepath_or_buffer

filepaths_or_buffers = []
@@ -547,7 +552,8 @@
"for full CPU-based filtering functionality."
)

return _parquet_to_frame(
# Convert parquet data to a cudf.DataFrame
df = _parquet_to_frame(
filepaths_or_buffers,
engine,
*args,
@@ -561,6 +567,115 @@
**kwargs,
)

# Apply filters row-wise (if any are defined), and return
return _apply_post_filters(df, filters)


def _normalize_filters(filters: list | None) -> List[List[tuple]] | None:
# Utility to normalize and validate the `filters`
# argument to `read_parquet`
if not filters:
return None

msg = (
f"filters must be None, or non-empty List[Tuple] "
f"or List[List[Tuple]]. Got {filters}"
)
if not isinstance(filters, list):
raise TypeError(msg)

def _validate_predicate(item):
if not isinstance(item, tuple) or len(item) != 3:
raise TypeError(
f"Predicate must be Tuple[str, str, Any], got {item}."
)

filters = filters if isinstance(filters[0], list) else [filters]
for conjunction in filters:
if not conjunction or not isinstance(conjunction, list):
raise TypeError(msg)
for predicate in conjunction:
_validate_predicate(predicate)

return filters


def _apply_post_filters(
df: cudf.DataFrame, filters: List[List[tuple]] | None
) -> cudf.DataFrame:
"""Apply DNF filters to an in-memory DataFrame

Disjunctive normal form (DNF) means that the inner-most
tuple describes a single column predicate. These inner
predicates are combined with an AND conjunction into a
larger predicate. The outer-most list then combines all
of the combined filters with an OR disjunction.
"""

if not filters:
# No filters to apply
return df

def _handle_in(column: cudf.Series, value, *, negate) -> cudf.Series:
if not isinstance(value, (list, set, tuple)):
raise TypeError(
"Value of 'in'/'not in' filter must be a list, set, or tuple."
)
return ~column.isin(value) if negate else column.isin(value)

def _handle_is(column: cudf.Series, value, *, negate) -> cudf.Series:
if value not in {np.nan, None}:
raise TypeError(
"Value of 'is'/'is not' filter must be np.nan or None."
)
return ~column.isna() if negate else column.isna()

handlers: Dict[str, Callable] = {
"==": operator.eq,
"!=": operator.ne,
"<": operator.lt,
"<=": operator.le,
">": operator.gt,
">=": operator.ge,
"in": partial(_handle_in, negate=False),
"not in": partial(_handle_in, negate=True),
"is": partial(_handle_is, negate=False),
"is not": partial(_handle_is, negate=True),
}

# Can re-set the index before returning if we filter
# out rows from a DataFrame with a default RangeIndex
# (to reduce memory usage)
reset_index = (
isinstance(df.index, cudf.RangeIndex)
and df.index.name is None
and df.index.start == 0
and df.index.step == 1
)

try:
selection: cudf.Series = reduce(
operator.or_,
(
reduce(
operator.and_,
(
handlers[op](df[column], value)
for (column, op, value) in expr
),
)
for expr in filters
),
)
if reset_index:
return df[selection].reset_index(drop=True)
return df[selection]
except (KeyError, TypeError):
warnings.warn(
f"Row-wise filtering failed in read_parquet for {filters}"
)
return df


@_cudf_nvtx_annotate
def _parquet_to_frame(
45 changes: 20 additions & 25 deletions python/cudf/cudf/tests/test_parquet.py
@@ -528,9 +528,7 @@ def test_parquet_read_filtered_multiple_files(tmpdir):
)
assert_eq(
filtered_df,
cudf.DataFrame(
{"x": [2, 3, 2, 3], "y": list("bbcc")}, index=[2, 3, 2, 3]
),
cudf.DataFrame({"x": [2, 2], "y": list("bc")}, index=[2, 2]),
)


@@ -541,13 +539,16 @@ def test_parquet_read_filtered_multiple_files(tmpdir):
@pytest.mark.parametrize(
"predicate,expected_len",
[
([[("x", "==", 0)], [("z", "==", 0)]], 4),
([("x", "==", 0), ("z", "==", 0)], 0),
([("x", "==", 0), ("z", "!=", 0)], 2),
([[("x", "==", 0)], [("z", "==", 0)]], 2),
([("x", "==", 0), ("z", "==", 0)], 0),
([("x", "==", 0), ("z", "!=", 0)], 1),
([("y", "==", "c"), ("x", ">", 8)], 0),
([("y", "==", "c"), ("x", ">=", 5)], 2),
([[("y", "==", "c")], [("x", "<", 3)]], 6),
([("y", "==", "c"), ("x", ">=", 5)], 1),
([[("y", "==", "c")], [("x", "<", 3)]], 5),
([[("x", "not in", (0, 9)), ("z", "not in", (4, 5))]], 6),
([[("y", "==", "c")], [("x", "in", (0, 9)), ("z", "in", (0, 9))]], 4),
([[("x", "==", 0)], [("x", "==", 1)], [("x", "==", 2)]], 3),
([[("x", "==", 0), ("z", "==", 9), ("y", "==", "a")]], 1),
],
)
def test_parquet_read_filtered_complex_predicate(
@@ -556,7 +557,11 @@ def test_parquet_read_filtered_complex_predicate(
# Generate data
fname = tmpdir.join("filtered_complex_predicate.parquet")
df = pd.DataFrame(
{"x": range(10), "y": list("aabbccddee"), "z": reversed(range(10))}
{
"x": range(10),
"y": list("aabbccddee"),
"z": reversed(range(10)),
}
)
df.to_parquet(fname, row_group_size=2)

@@ -1954,26 +1959,16 @@ def test_read_parquet_partitioned_filtered(
assert got.dtypes["c"] == "int"
assert_eq(expect, got)

# Filter on non-partitioned column.
# Cannot compare to pandas, since the pyarrow
# backend will filter by row (and cudf can
# only filter by column, for now)
# Filter on non-partitioned column
filters = [("a", "==", 10)]
got = cudf.read_parquet(
read_path,
filters=filters,
row_groups=row_groups,
)
assert len(got) < len(df) and 10 in got["a"]
got = cudf.read_parquet(read_path, filters=filters)
expect = pd.read_parquet(read_path, filters=filters)

# Filter on both kinds of columns
filters = [[("a", "==", 10)], [("c", "==", 1)]]
got = cudf.read_parquet(
read_path,
filters=filters,
row_groups=row_groups,
)
assert len(got) < len(df) and (1 in got["c"] and 10 in got["a"])
got = cudf.read_parquet(read_path, filters=filters)
expect = pd.read_parquet(read_path, filters=filters)
assert_eq(expect, got)


def test_parquet_writer_chunked_metadata(tmpdir, simple_pdf, simple_gdf):
7 changes: 4 additions & 3 deletions python/cudf/cudf/utils/ioutils.py
@@ -147,11 +147,12 @@
For other URLs (e.g. starting with "s3://", and "gcs://") the key-value
pairs are forwarded to ``fsspec.open``. Please see ``fsspec`` and
``urllib`` for more details.
filters : list of tuple, list of lists of tuples default None
filters : list of tuple, list of lists of tuples, default None
If not None, specifies a filter predicate used to filter out row groups
using statistics stored for each row group as Parquet metadata. Row groups
that do not match the given filter predicate are not read. The
predicate is expressed in disjunctive normal form (DNF) like
that do not match the given filter predicate are not read. The filters
will also be applied to the rows of the in-memory DataFrame after IO.
The predicate is expressed in disjunctive normal form (DNF) like
`[[('x', '=', 0), ...], ...]`. DNF allows arbitrary boolean logical
combinations of single column predicates. The innermost tuples each
describe a single column predicate. The list of inner predicates is
14 changes: 13 additions & 1 deletion python/dask_cudf/dask_cudf/io/parquet.py
@@ -20,7 +20,11 @@
import cudf
from cudf.core.column import as_column, build_categorical_column
from cudf.io import write_to_dataset
from cudf.io.parquet import _default_open_file_options
from cudf.io.parquet import (
_apply_post_filters,
_default_open_file_options,
_normalize_filters,
)
from cudf.utils.dtypes import cudf_dtype_from_pa_type
from cudf.utils.ioutils import (
_ROW_GROUP_SIZE_BYTES_DEFAULT,
@@ -69,6 +73,7 @@ def _read_paths(
fs,
columns=None,
row_groups=None,
filters=None,
strings_to_categorical=None,
partitions=None,
partitioning=None,
@@ -134,6 +139,10 @@
else:
raise err

# Apply filters (if any are defined)
filters = _normalize_filters(filters)
df = _apply_post_filters(df, filters)

if partitions and partition_keys is None:

# Use `HivePartitioning` by default
@@ -183,6 +192,7 @@ def read_partition(
index,
categories=(),
partitions=(),
filters=None,
partitioning=None,
schema=None,
open_file_options=None,
@@ -255,6 +265,7 @@ def read_partition(
fs,
columns=read_columns,
row_groups=rgs if rgs else None,
filters=filters,
strings_to_categorical=strings_to_cats,
partitions=partitions,
partitioning=partitioning,
@@ -281,6 +292,7 @@ def read_partition(
fs,
columns=read_columns,
row_groups=rgs if rgs else None,
filters=filters,
strings_to_categorical=strings_to_cats,
partitions=partitions,
partitioning=partitioning,
37 changes: 36 additions & 1 deletion python/dask_cudf/dask_cudf/io/tests/test_parquet.py
@@ -254,6 +254,41 @@ def test_filters(tmpdir):
assert not len(c)


@pytest.mark.parametrize("numeric", [True, False])
@pytest.mark.parametrize("null", [np.nan, None])
def test_isna_filters(tmpdir, null, numeric):

tmp_path = str(tmpdir)
df = pd.DataFrame(
{
"x": range(10),
"y": list("aabbccddee"),
"i": [0] * 4 + [np.nan] * 2 + [0] * 4,
"j": [""] * 4 + [None] * 2 + [""] * 4,
}
)
ddf = dd.from_pandas(df, npartitions=5)
assert ddf.npartitions == 5
ddf.to_parquet(tmp_path, engine="pyarrow")

# Test "is"
col = "i" if numeric else "j"
filters = [(col, "is", null)]
out = dask_cudf.read_parquet(
tmp_path, filters=filters, split_row_groups=True
)
assert len(out) == 2
assert list(out.x.compute().values) == [4, 5]

# Test "is not"
filters = [(col, "is not", null)]
out = dask_cudf.read_parquet(
tmp_path, filters=filters, split_row_groups=True
)
assert len(out) == 8
assert list(out.x.compute().values) == [0, 1, 2, 3, 6, 7, 8, 9]


def test_filters_at_row_group_level(tmpdir):

tmp_path = str(tmpdir)
@@ -267,7 +302,7 @@ def test_filters_at_row_group_level(tmpdir):
tmp_path, filters=[("x", "==", 1)], split_row_groups=True
)
assert a.npartitions == 1
assert (a.shape[0] == 2).compute()
assert (a.shape[0] == 1).compute()

ddf.to_parquet(tmp_path, engine="pyarrow", row_group_size=1)

