From 6488c5ed29ce1ec1a4ac500938e5c4537c1aca3a Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 10 May 2023 14:25:32 -0700 Subject: [PATCH 01/10] support row-wise filtering - with test coverage --- python/cudf/cudf/io/parquet.py | 116 +++++++++++++++++- python/cudf/cudf/tests/test_parquet.py | 35 ++---- python/dask_cudf/dask_cudf/io/parquet.py | 6 + .../dask_cudf/io/tests/test_parquet.py | 2 +- 4 files changed, 133 insertions(+), 26 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 3e1a4b1f024..2cc5b2950e2 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -1,6 +1,7 @@ # Copyright (c) 2019-2023, NVIDIA CORPORATION. import math +import operator import shutil import tempfile import warnings @@ -9,6 +10,7 @@ from typing import Dict, List, Optional, Tuple from uuid import uuid4 +import numpy as np import pandas as pd from pyarrow import dataset as ds, parquet as pq @@ -547,7 +549,8 @@ def read_parquet( "for full CPU-based filtering functionality." ) - return _parquet_to_frame( + # Convert parquet data to a cudf.DataFrame + df = _parquet_to_frame( filepaths_or_buffers, engine, *args, @@ -561,6 +564,117 @@ def read_parquet( **kwargs, ) + # Can re-set the index before returning if we filter + # out rows from a DataFrame with a default RangeIndex + # (to reduce memory usage) + reset_index = filters and ( + isinstance(df.index, cudf.RangeIndex) + and df.index.name is None + and df.index.start == 0 + and df.index.step == 1 + ) + + # Apply filters (if any are defined) + df = _apply_dnf_filters(df, filters) + + # Return final cudf.DataFrame + return df.reset_index(drop=True) if reset_index else df + + +def _apply_dnf_filters(df, filters): + # Apply DNF filters to a DataFrame + # + # Disjunctive normal form (DNF) means that the inner-most + # tuple describes a single column predicate. These inner + # predicates are combined with an AND conjunction into a + # larger predicate. The outer-most list then combines all + # of the combined filters with an OR disjunction. + + if not filters: + # No filters to apply + return df + + _comparisons = { + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, + } + + try: + # Disjunction loop + # + # All elements of `disjunctions` shall be combined with + # an `OR` disjunction (operator.or_) + disjunctions = [] + other = filters.copy() + while other: + + # Conjunction loop + # + # All elements of `conjunctions` shall be combined with + # an `AND` conjunction (operator.and_) + conjunctions = [] + comparisons, *other = ( + other if isinstance(other[0], list) else [other] + ) + for (column, op, value) in comparisons: + + # Inner comparison loop + # + # `op` is expected to be the string representation + # of a comparison operator (e.g. "==") + if op == "in": + # Special case: "in" + if not isinstance(value, (list, set, tuple)): + raise TypeError( + "Value of 'in' filter must be a " + "list, set, or tuple." + ) + if len(value) == 1: + conjunctions.append(operator.eq(df[column], value[0])) + else: + conjunctions.append( + operator.or_( + *[operator.eq(df[column], v) for v in value] + ) + ) + elif op in ("is", "is not"): + # Special case: "is" or "is not" + if value not in (np.nan, None): + raise TypeError( + "Value of 'is' or 'is not' filter " + "must be np.nan or None." + ) + conjunctions.append( + df[column].isna() if op == "is" else ~df[column].isna() + ) + else: + # Conventional comparison operator + conjunctions.append(_comparisons[op](df[column], value)) + + disjunctions.append( + operator.and_(*conjunctions) + if len(conjunctions) > 1 + else conjunctions[0] + ) + + return df[ + operator.or_(*disjunctions) + if len(disjunctions) > 1 + else disjunctions[0] + ] + + except (KeyError, TypeError): + + # Unsupported op or value + warnings.warn( + f"Row-wise filtering failed in read_parquet for {filters}" + ) + return df + @_cudf_nvtx_annotate def _parquet_to_frame( diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 0ab5d35f9f8..3939746e759 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -528,9 +528,7 @@ def test_parquet_read_filtered_multiple_files(tmpdir): ) assert_eq( filtered_df, - cudf.DataFrame( - {"x": [2, 3, 2, 3], "y": list("bbcc")}, index=[2, 3, 2, 3] - ), + cudf.DataFrame({"x": [2, 2], "y": list("bc")}, index=[2, 2]), ) @@ -541,13 +539,12 @@ def test_parquet_read_filtered_multiple_files(tmpdir): @pytest.mark.parametrize( "predicate,expected_len", [ - ([[("x", "==", 0)], [("z", "==", 0)]], 4), - ([("x", "==", 0), ("z", "==", 0)], 0), - ([("x", "==", 0), ("z", "!=", 0)], 2), + ([[("x", "==", 0)], [("z", "==", 0)]], 2), ([("x", "==", 0), ("z", "==", 0)], 0), + ([("x", "==", 0), ("z", "!=", 0)], 1), ([("y", "==", "c"), ("x", ">", 8)], 0), - ([("y", "==", "c"), ("x", ">=", 5)], 2), - ([[("y", "==", "c")], [("x", "<", 3)]], 6), + ([("y", "==", "c"), ("x", ">=", 5)], 1), + ([[("y", "==", "c")], [("x", "<", 3)]], 5), ], ) def test_parquet_read_filtered_complex_predicate( @@ -1954,26 +1951,16 @@ def test_read_parquet_partitioned_filtered( assert got.dtypes["c"] == "int" assert_eq(expect, got) - # Filter on non-partitioned column. - # Cannot compare to pandas, since the pyarrow - # backend will filter by row (and cudf can - # only filter by column, for now) + # Filter on non-partitioned column filters = [("a", "==", 10)] - got = cudf.read_parquet( - read_path, - filters=filters, - row_groups=row_groups, - ) - assert len(got) < len(df) and 10 in got["a"] + got = cudf.read_parquet(read_path, filters=filters) + expect = pd.read_parquet(read_path, filters=filters) # Filter on both kinds of columns filters = [[("a", "==", 10)], [("c", "==", 1)]] - got = cudf.read_parquet( - read_path, - filters=filters, - row_groups=row_groups, - ) - assert len(got) < len(df) and (1 in got["c"] and 10 in got["a"]) + got = cudf.read_parquet(read_path, filters=filters) + expect = pd.read_parquet(read_path, filters=filters) + assert_eq(expect, got) def test_parquet_writer_chunked_metadata(tmpdir, simple_pdf, simple_gdf): diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index f19c373150d..6e1ebaed6a0 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -66,6 +66,7 @@ def _read_paths( fs, columns=None, row_groups=None, + filters=None, strings_to_categorical=None, partitions=None, partitioning=None, @@ -102,6 +103,7 @@ def _read_paths( engine="cudf", columns=columns, row_groups=row_groups if row_groups else None, + filters=filters, strings_to_categorical=strings_to_categorical, dataset_kwargs=dataset_kwargs, categorical_partitions=False, @@ -120,6 +122,7 @@ def _read_paths( row_groups=row_groups[i] if row_groups else None, + filters=filters, strings_to_categorical=strings_to_categorical, dataset_kwargs=dataset_kwargs, categorical_partitions=False, @@ -180,6 +183,7 @@ def read_partition( index, categories=(), partitions=(), + filters=None, partitioning=None, schema=None, open_file_options=None, @@ -252,6 +256,7 @@ def read_partition( fs, columns=read_columns, row_groups=rgs if rgs else None, + filters=filters, strings_to_categorical=strings_to_cats, partitions=partitions, partitioning=partitioning, @@ -278,6 +283,7 @@ def read_partition( fs, columns=read_columns, row_groups=rgs if rgs else None, + filters=filters, strings_to_categorical=strings_to_cats, partitions=partitions, partitioning=partitioning, diff --git a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py index f5ae9706fde..7ba004ae137 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py @@ -267,7 +267,7 @@ def test_filters_at_row_group_level(tmpdir): tmp_path, filters=[("x", "==", 1)], split_row_groups=True ) assert a.npartitions == 1 - assert (a.shape[0] == 2).compute() + assert (a.shape[0] == 1).compute() ddf.to_parquet(tmp_path, engine="pyarrow", row_group_size=1) From 6932869703d6e91446db630f66df62fb8db4acd9 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Wed, 10 May 2023 15:12:30 -0700 Subject: [PATCH 02/10] add post_filters --- python/cudf/cudf/io/parquet.py | 12 +++++---- python/cudf/cudf/tests/test_parquet.py | 33 +++++++++++++++++++++++- python/cudf/cudf/utils/ioutils.py | 6 ++++- python/dask_cudf/dask_cudf/io/parquet.py | 4 +-- 4 files changed, 46 insertions(+), 9 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 2cc5b2950e2..35dc44ea72a 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -441,6 +441,7 @@ def read_parquet( open_file_options=None, bytes_per_thread=None, dataset_kwargs=None, + post_filters=None, *args, **kwargs, ): @@ -567,22 +568,23 @@ def read_parquet( # Can re-set the index before returning if we filter # out rows from a DataFrame with a default RangeIndex # (to reduce memory usage) - reset_index = filters and ( + post_filters = post_filters or filters + reset_index = post_filters and ( isinstance(df.index, cudf.RangeIndex) and df.index.name is None and df.index.start == 0 and df.index.step == 1 ) - # Apply filters (if any are defined) - df = _apply_dnf_filters(df, filters) + # Apply post_filters (if any are defined) + df = _apply_post_filters(df, post_filters) # Return final cudf.DataFrame return df.reset_index(drop=True) if reset_index else df -def _apply_dnf_filters(df, filters): - # Apply DNF filters to a DataFrame +def _apply_post_filters(df, filters): + # Apply DNF filters to an in-memory DataFrame # # Disjunctive normal form (DNF) means that the inner-most # tuple describes a single column predicate. These inner diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 3939746e759..854e75f6bbd 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -545,6 +545,7 @@ def test_parquet_read_filtered_multiple_files(tmpdir): ([("y", "==", "c"), ("x", ">", 8)], 0), ([("y", "==", "c"), ("x", ">=", 5)], 1), ([[("y", "==", "c")], [("x", "<", 3)]], 5), + ([[("y", "==", "c")], [("x", "in", (0, 9))]], 4), ], ) def test_parquet_read_filtered_complex_predicate( @@ -553,7 +554,13 @@ def test_parquet_read_filtered_complex_predicate( # Generate data fname = tmpdir.join("filtered_complex_predicate.parquet") df = pd.DataFrame( - {"x": range(10), "y": list("aabbccddee"), "z": reversed(range(10))} + { + "x": range(10), + "y": list("aabbccddee"), + "z": reversed(range(10)), + "j": [0] * 4 + [np.nan] * 2 + [0] * 4, + "k": [""] * 4 + [None] * 2 + [""] * 4, + } ) df.to_parquet(fname, row_group_size=2) @@ -563,6 +570,30 @@ def test_parquet_read_filtered_complex_predicate( assert_eq(len(df_filtered), expected_len) +@pytest.mark.parametrize( + "predicate,expected_len", + [ + ([[("j", "is not", np.nan)], [("i", "<", 3)]], 8), + ([("k", "is", None)], 2), + ], +) +def test_parquet_post_filters(tmpdir, predicate, expected_len): + # Check that "is" and "is not" are supported + # as `post_filters` (even though they are not + # supported by pyarrow) + fname = tmpdir.join("post_filters.parquet") + df = pd.DataFrame( + { + "i": range(10), + "j": [0] * 4 + [np.nan] * 2 + [0] * 4, + "k": [""] * 4 + [None] * 2 + [""] * 4, + } + ) + df.to_parquet(fname, row_group_size=2) + df_filtered = cudf.read_parquet(fname, post_filters=predicate) + assert_eq(len(df_filtered), expected_len) + + @pytest.mark.parametrize("row_group_size", [1, 5, 100]) def test_parquet_read_row_groups(tmpdir, pdf, row_group_size): if len(pdf) > 100: diff --git a/python/cudf/cudf/utils/ioutils.py b/python/cudf/cudf/utils/ioutils.py index bf51b360fec..8b8fbb00482 100644 --- a/python/cudf/cudf/utils/ioutils.py +++ b/python/cudf/cudf/utils/ioutils.py @@ -147,7 +147,7 @@ For other URLs (e.g. starting with "s3://", and "gcs://") the key-value pairs are forwarded to ``fsspec.open``. Please see ``fsspec`` and ``urllib`` for more details. -filters : list of tuple, list of lists of tuples default None +filters : list of tuple, list of lists of tuples, default None If not None, specifies a filter predicate used to filter out row groups using statistics stored for each row group as Parquet metadata. Row groups that do not match the given filter predicate are not read. The @@ -161,6 +161,10 @@ as a list of tuples. This form is interpreted as a single conjunction. To express OR in predicates, one must use the (preferred) notation of list of lists of tuples. +post_filters : list of tuple, list of lists of tuples, default None + Row-wise filters to be applied to the in-memory `DataFrame` after IO + is performed. If `None` (the default), `post_filters` will be set equal + to the value of `filters`. row_groups : int, or list, or a list of lists default None If not None, specifies, for each input file, which row groups to read. If reading multiple inputs, a list of lists should be passed, one list diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index 6e1ebaed6a0..896ce10a02a 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -103,7 +103,7 @@ def _read_paths( engine="cudf", columns=columns, row_groups=row_groups if row_groups else None, - filters=filters, + post_filters=filters, strings_to_categorical=strings_to_categorical, dataset_kwargs=dataset_kwargs, categorical_partitions=False, @@ -122,7 +122,7 @@ def _read_paths( row_groups=row_groups[i] if row_groups else None, - filters=filters, + post_filters=filters, strings_to_categorical=strings_to_categorical, dataset_kwargs=dataset_kwargs, categorical_partitions=False, From dcddab07537f404af0882b2db01ba75f01e5d66e Mon Sep 17 00:00:00 2001 From: rjzamora Date: Thu, 11 May 2023 20:38:26 -0700 Subject: [PATCH 03/10] revise for code review --- python/cudf/cudf/io/parquet.py | 117 +++++++++-------------- python/cudf/cudf/tests/test_parquet.py | 30 +----- python/cudf/cudf/utils/ioutils.py | 9 +- python/dask_cudf/dask_cudf/io/parquet.py | 7 +- 4 files changed, 53 insertions(+), 110 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 35dc44ea72a..e357f27c914 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -7,6 +7,7 @@ import warnings from collections import defaultdict from contextlib import ExitStack +from functools import partial, reduce from typing import Dict, List, Optional, Tuple from uuid import uuid4 @@ -441,7 +442,6 @@ def read_parquet( open_file_options=None, bytes_per_thread=None, dataset_kwargs=None, - post_filters=None, *args, **kwargs, ): @@ -565,22 +565,24 @@ def read_parquet( **kwargs, ) - # Can re-set the index before returning if we filter - # out rows from a DataFrame with a default RangeIndex - # (to reduce memory usage) - post_filters = post_filters or filters - reset_index = post_filters and ( - isinstance(df.index, cudf.RangeIndex) - and df.index.name is None - and df.index.start == 0 - and df.index.step == 1 - ) + # Apply filters row-wise (if any are defined), and return + return _apply_post_filters(df, filters) - # Apply post_filters (if any are defined) - df = _apply_post_filters(df, post_filters) - # Return final cudf.DataFrame - return df.reset_index(drop=True) if reset_index else df +def _handle_in(column, value): + if not isinstance(value, (list, set, tuple)): + raise TypeError( + "Value of 'in' filter must be a " "list, set, or tuple." + ) + return reduce(operator.or_, (operator.eq(column, v) for v in value)) + + +def _handle_is(column, value, *, negate): + if value not in {np.nan, None}: + raise TypeError( + "Value of 'is' or 'is not' filter " "must be np.nan or None." + ) + return ~column.isna() if negate else column.isna() def _apply_post_filters(df, filters): @@ -596,82 +598,49 @@ def _apply_post_filters(df, filters): # No filters to apply return df - _comparisons = { + handlers = { "==": operator.eq, "!=": operator.ne, "<": operator.lt, "<=": operator.le, ">": operator.gt, ">=": operator.ge, + "in": _handle_in, + "is": partial(_handle_is, negate=False), + "is not": partial(_handle_is, negate=True), } + # Can re-set the index before returning if we filter + # out rows from a DataFrame with a default RangeIndex + # (to reduce memory usage) + reset_index = filters and ( + isinstance(df.index, cudf.RangeIndex) + and df.index.name is None + and df.index.start == 0 + and df.index.step == 1 + ) + try: # Disjunction loop # # All elements of `disjunctions` shall be combined with # an `OR` disjunction (operator.or_) disjunctions = [] - other = filters.copy() - while other: - - # Conjunction loop - # - # All elements of `conjunctions` shall be combined with - # an `AND` conjunction (operator.and_) - conjunctions = [] - comparisons, *other = ( - other if isinstance(other[0], list) else [other] - ) - for (column, op, value) in comparisons: - - # Inner comparison loop - # - # `op` is expected to be the string representation - # of a comparison operator (e.g. "==") - if op == "in": - # Special case: "in" - if not isinstance(value, (list, set, tuple)): - raise TypeError( - "Value of 'in' filter must be a " - "list, set, or tuple." - ) - if len(value) == 1: - conjunctions.append(operator.eq(df[column], value[0])) - else: - conjunctions.append( - operator.or_( - *[operator.eq(df[column], v) for v in value] - ) - ) - elif op in ("is", "is not"): - # Special case: "is" or "is not" - if value not in (np.nan, None): - raise TypeError( - "Value of 'is' or 'is not' filter " - "must be np.nan or None." - ) - conjunctions.append( - df[column].isna() if op == "is" else ~df[column].isna() - ) - else: - # Conventional comparison operator - conjunctions.append(_comparisons[op](df[column], value)) - - disjunctions.append( - operator.and_(*conjunctions) - if len(conjunctions) > 1 - else conjunctions[0] + for expr in filters if isinstance(filters[0], list) else [filters]: + conjunction = reduce( + operator.and_, + ( + handlers[op](df[column], value) + for (column, op, value) in expr + ), ) + disjunctions.append(conjunction) - return df[ - operator.or_(*disjunctions) - if len(disjunctions) > 1 - else disjunctions[0] - ] - + selection = reduce(operator.or_, disjunctions) + if reset_index: + return df[selection].reset_index(drop=True) + return df[selection] except (KeyError, TypeError): - - # Unsupported op or value warnings.warn( f"Row-wise filtering failed in read_parquet for {filters}" ) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 854e75f6bbd..30c0ddd7067 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -545,7 +545,9 @@ def test_parquet_read_filtered_multiple_files(tmpdir): ([("y", "==", "c"), ("x", ">", 8)], 0), ([("y", "==", "c"), ("x", ">=", 5)], 1), ([[("y", "==", "c")], [("x", "<", 3)]], 5), - ([[("y", "==", "c")], [("x", "in", (0, 9))]], 4), + ([[("y", "==", "c")], [("x", "in", (0, 9)), ("z", "in", (0, 9))]], 4), + ([[("x", "==", 0)], [("x", "==", 1)], [("x", "==", 2)]], 3), + ([[("x", "==", 0), ("z", "==", 9), ("y", "==", "a")]], 1), ], ) def test_parquet_read_filtered_complex_predicate( @@ -558,8 +560,6 @@ def test_parquet_read_filtered_complex_predicate( "x": range(10), "y": list("aabbccddee"), "z": reversed(range(10)), - "j": [0] * 4 + [np.nan] * 2 + [0] * 4, - "k": [""] * 4 + [None] * 2 + [""] * 4, } ) df.to_parquet(fname, row_group_size=2) @@ -570,30 +570,6 @@ def test_parquet_read_filtered_complex_predicate( assert_eq(len(df_filtered), expected_len) -@pytest.mark.parametrize( - "predicate,expected_len", - [ - ([[("j", "is not", np.nan)], [("i", "<", 3)]], 8), - ([("k", "is", None)], 2), - ], -) -def test_parquet_post_filters(tmpdir, predicate, expected_len): - # Check that "is" and "is not" are supported - # as `post_filters` (even though they are not - # supported by pyarrow) - fname = tmpdir.join("post_filters.parquet") - df = pd.DataFrame( - { - "i": range(10), - "j": [0] * 4 + [np.nan] * 2 + [0] * 4, - "k": [""] * 4 + [None] * 2 + [""] * 4, - } - ) - df.to_parquet(fname, row_group_size=2) - df_filtered = cudf.read_parquet(fname, post_filters=predicate) - assert_eq(len(df_filtered), expected_len) - - @pytest.mark.parametrize("row_group_size", [1, 5, 100]) def test_parquet_read_row_groups(tmpdir, pdf, row_group_size): if len(pdf) > 100: diff --git a/python/cudf/cudf/utils/ioutils.py b/python/cudf/cudf/utils/ioutils.py index 8b8fbb00482..bf9dc226d11 100644 --- a/python/cudf/cudf/utils/ioutils.py +++ b/python/cudf/cudf/utils/ioutils.py @@ -150,8 +150,9 @@ filters : list of tuple, list of lists of tuples, default None If not None, specifies a filter predicate used to filter out row groups using statistics stored for each row group as Parquet metadata. Row groups - that do not match the given filter predicate are not read. The - predicate is expressed in disjunctive normal form (DNF) like + that do not match the given filter predicate are not read. The filters + will also be applied to the rows of the in-memory DataFrame after IO. + The predicate is expressed in disjunctive normal form (DNF) like `[[('x', '=', 0), ...], ...]`. DNF allows arbitrary boolean logical combinations of single column predicates. The innermost tuples each describe a single column predicate. The list of inner predicates is @@ -161,10 +162,6 @@ as a list of tuples. This form is interpreted as a single conjunction. To express OR in predicates, one must use the (preferred) notation of list of lists of tuples. -post_filters : list of tuple, list of lists of tuples, default None - Row-wise filters to be applied to the in-memory `DataFrame` after IO - is performed. If `None` (the default), `post_filters` will be set equal - to the value of `filters`. row_groups : int, or list, or a list of lists default None If not None, specifies, for each input file, which row groups to read. If reading multiple inputs, a list of lists should be passed, one list diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index 896ce10a02a..3ffb71f0d2e 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -20,7 +20,7 @@ import cudf from cudf.core.column import as_column, build_categorical_column from cudf.io import write_to_dataset -from cudf.io.parquet import _default_open_file_options +from cudf.io.parquet import _apply_post_filters, _default_open_file_options from cudf.utils.dtypes import cudf_dtype_from_pa_type from cudf.utils.ioutils import ( _ROW_GROUP_SIZE_BYTES_DEFAULT, @@ -103,7 +103,6 @@ def _read_paths( engine="cudf", columns=columns, row_groups=row_groups if row_groups else None, - post_filters=filters, strings_to_categorical=strings_to_categorical, dataset_kwargs=dataset_kwargs, categorical_partitions=False, @@ -122,7 +121,6 @@ def _read_paths( row_groups=row_groups[i] if row_groups else None, - post_filters=filters, strings_to_categorical=strings_to_categorical, dataset_kwargs=dataset_kwargs, categorical_partitions=False, @@ -134,6 +132,9 @@ def _read_paths( else: raise err + # Apply filters (if any are defined) + df = _apply_post_filters(df, filters) + if partitions and partition_keys is None: # Use `HivePartitioning` by default From f833445cc9aaea5848376f6b74f6d1bec44ef6c9 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 12 May 2023 11:09:47 -0700 Subject: [PATCH 04/10] add not in support --- python/cudf/cudf/io/parquet.py | 14 +++++--- python/cudf/cudf/tests/test_parquet.py | 1 + .../dask_cudf/io/tests/test_parquet.py | 35 +++++++++++++++++++ 3 files changed, 45 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index e357f27c914..40b2fef164f 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -569,18 +569,21 @@ def read_parquet( return _apply_post_filters(df, filters) -def _handle_in(column, value): +def _handle_in(column, value, *, negate): if not isinstance(value, (list, set, tuple)): raise TypeError( - "Value of 'in' filter must be a " "list, set, or tuple." + "Value of 'in' or 'not in' filter must be a list, set, or tuple." ) - return reduce(operator.or_, (operator.eq(column, v) for v in value)) + if negate: + return reduce(operator.and_, (operator.ne(column, v) for v in value)) + else: + return reduce(operator.or_, (operator.eq(column, v) for v in value)) def _handle_is(column, value, *, negate): if value not in {np.nan, None}: raise TypeError( - "Value of 'is' or 'is not' filter " "must be np.nan or None." + "Value of 'is' or 'is not' filter must be np.nan or None." ) return ~column.isna() if negate else column.isna() @@ -605,7 +608,8 @@ def _apply_post_filters(df, filters): "<=": operator.le, ">": operator.gt, ">=": operator.ge, - "in": _handle_in, + "in": partial(_handle_in, negate=False), + "not in": partial(_handle_in, negate=True), "is": partial(_handle_is, negate=False), "is not": partial(_handle_is, negate=True), } diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 30c0ddd7067..7aeb2afc2c7 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -545,6 +545,7 @@ def test_parquet_read_filtered_multiple_files(tmpdir): ([("y", "==", "c"), ("x", ">", 8)], 0), ([("y", "==", "c"), ("x", ">=", 5)], 1), ([[("y", "==", "c")], [("x", "<", 3)]], 5), + ([[("x", "not in", (0, 9)), ("z", "not in", (4, 5))]], 6), ([[("y", "==", "c")], [("x", "in", (0, 9)), ("z", "in", (0, 9))]], 4), ([[("x", "==", 0)], [("x", "==", 1)], [("x", "==", 2)]], 3), ([[("x", "==", 0), ("z", "==", 9), ("y", "==", "a")]], 1), diff --git a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py index 7ba004ae137..8e80aad67d1 100644 --- a/python/dask_cudf/dask_cudf/io/tests/test_parquet.py +++ b/python/dask_cudf/dask_cudf/io/tests/test_parquet.py @@ -254,6 +254,41 @@ def test_filters(tmpdir): assert not len(c) +@pytest.mark.parametrize("numeric", [True, False]) +@pytest.mark.parametrize("null", [np.nan, None]) +def test_isna_filters(tmpdir, null, numeric): + + tmp_path = str(tmpdir) + df = pd.DataFrame( + { + "x": range(10), + "y": list("aabbccddee"), + "i": [0] * 4 + [np.nan] * 2 + [0] * 4, + "j": [""] * 4 + [None] * 2 + [""] * 4, + } + ) + ddf = dd.from_pandas(df, npartitions=5) + assert ddf.npartitions == 5 + ddf.to_parquet(tmp_path, engine="pyarrow") + + # Test "is" + col = "i" if numeric else "j" + filters = [(col, "is", null)] + out = dask_cudf.read_parquet( + tmp_path, filters=filters, split_row_groups=True + ) + assert len(out) == 2 + assert list(out.x.compute().values) == [4, 5] + + # Test "is not" + filters = [(col, "is not", null)] + out = dask_cudf.read_parquet( + tmp_path, filters=filters, split_row_groups=True + ) + assert len(out) == 8 + assert list(out.x.compute().values) == [0, 1, 2, 3, 6, 7, 8, 9] + + def test_filters_at_row_group_level(tmpdir): tmp_path = str(tmpdir) From b76bc938953eb8b0fa3695c300b56bddd39667b7 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 12 May 2023 12:12:30 -0700 Subject: [PATCH 05/10] basic type hints --- python/cudf/cudf/io/parquet.py | 23 ++++++++++++++++++----- python/dask_cudf/dask_cudf/io/parquet.py | 2 ++ 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 40b2fef164f..b04612136e9 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -1,4 +1,5 @@ # Copyright (c) 2019-2023, NVIDIA CORPORATION. +from __future__ import annotations import math import operator @@ -8,7 +9,7 @@ from collections import defaultdict from contextlib import ExitStack from functools import partial, reduce -from typing import Dict, List, Optional, Tuple +from typing import Callable, Dict, List, Optional, Tuple from uuid import uuid4 import numpy as np @@ -566,6 +567,8 @@ def read_parquet( ) # Apply filters row-wise (if any are defined), and return + if filters is not None and not isinstance(filters, list): + raise TypeError(f"filters must be list, got {type(filters)}.") return _apply_post_filters(df, filters) @@ -588,7 +591,7 @@ def _handle_is(column, value, *, negate): return ~column.isna() if negate else column.isna() -def _apply_post_filters(df, filters): +def _apply_post_filters(df, filters: list | None): # Apply DNF filters to an in-memory DataFrame # # Disjunctive normal form (DNF) means that the inner-most @@ -601,7 +604,13 @@ def _apply_post_filters(df, filters): # No filters to apply return df - handlers = { + if filters is not None and ( + not isinstance(filters, list) + or not isinstance(filters[0], (list, tuple)) + ): + raise TypeError("filters must be List[Tuple] or ListList[[Tuple]]") + + handlers: Dict[str, Callable] = { "==": operator.eq, "!=": operator.ne, "<": operator.lt, @@ -617,20 +626,24 @@ def _apply_post_filters(df, filters): # Can re-set the index before returning if we filter # out rows from a DataFrame with a default RangeIndex # (to reduce memory usage) - reset_index = filters and ( + reset_index = ( isinstance(df.index, cudf.RangeIndex) and df.index.name is None and df.index.start == 0 and df.index.step == 1 ) + # Make sure we have List[List[Tuple]]. If filters + # is List[Tuple], then we only have a conjunction + expressions = filters if isinstance(filters[0], list) else [filters] + try: # Disjunction loop # # All elements of `disjunctions` shall be combined with # an `OR` disjunction (operator.or_) disjunctions = [] - for expr in filters if isinstance(filters[0], list) else [filters]: + for expr in expressions: conjunction = reduce( operator.and_, ( diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index 3ffb71f0d2e..cb2f97eb613 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -133,6 +133,8 @@ def _read_paths( raise err # Apply filters (if any are defined) + if filters is not None and not isinstance(filters, list): + raise TypeError(f"filters must be a list, got {type(filters)}.") df = _apply_post_filters(df, filters) if partitions and partition_keys is None: From 81fc31fe80b9b34ad3f3d7a29014545bd2539293 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 12 May 2023 14:15:00 -0700 Subject: [PATCH 06/10] better type hints --- python/cudf/cudf/io/parquet.py | 72 ++++++++++++++---------- python/dask_cudf/dask_cudf/io/parquet.py | 2 - 2 files changed, 43 insertions(+), 31 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index b04612136e9..32455405ebb 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -567,11 +567,37 @@ def read_parquet( ) # Apply filters row-wise (if any are defined), and return - if filters is not None and not isinstance(filters, list): - raise TypeError(f"filters must be list, got {type(filters)}.") return _apply_post_filters(df, filters) +def _normalize_filters(filters: list | None) -> List[List[tuple]] | None: + # Utility to normalize and validate the `filters` + # argument to `read_parquet` + if filters is not None: + msg = ( + f"filters must be None, or non-empty List[Tuple] " + f"or List[List[Tuple]]. Got {filters}" + ) + if not filters or not isinstance(filters, list): + raise TypeError(msg) + + def _validate_predicate(item): + if not isinstance(item, tuple) or len(item) != 3: + raise TypeError( + f"Predicate must be Tuple[str, str, Any], " + f"got {predicate}." + ) + + filters = filters if isinstance(filters[0], list) else [filters] + for conjunction in filters: + if not conjunction or not isinstance(conjunction, list): + raise TypeError(msg) + for predicate in conjunction: + _validate_predicate(predicate) + + return filters + + def _handle_in(column, value, *, negate): if not isinstance(value, (list, set, tuple)): raise TypeError( @@ -591,7 +617,7 @@ def _handle_is(column, value, *, negate): return ~column.isna() if negate else column.isna() -def _apply_post_filters(df, filters: list | None): +def _apply_post_filters(df: cudf.DataFrame, filters: list | None): # Apply DNF filters to an in-memory DataFrame # # Disjunctive normal form (DNF) means that the inner-most @@ -600,16 +626,11 @@ def _apply_post_filters(df, filters: list | None): # larger predicate. The outer-most list then combines all # of the combined filters with an OR disjunction. + filters = _normalize_filters(filters) if not filters: # No filters to apply return df - if filters is not None and ( - not isinstance(filters, list) - or not isinstance(filters[0], (list, tuple)) - ): - raise TypeError("filters must be List[Tuple] or ListList[[Tuple]]") - handlers: Dict[str, Callable] = { "==": operator.eq, "!=": operator.ne, @@ -633,27 +654,20 @@ def _apply_post_filters(df, filters: list | None): and df.index.step == 1 ) - # Make sure we have List[List[Tuple]]. If filters - # is List[Tuple], then we only have a conjunction - expressions = filters if isinstance(filters[0], list) else [filters] - try: - # Disjunction loop - # - # All elements of `disjunctions` shall be combined with - # an `OR` disjunction (operator.or_) - disjunctions = [] - for expr in expressions: - conjunction = reduce( - operator.and_, - ( - handlers[op](df[column], value) - for (column, op, value) in expr - ), - ) - disjunctions.append(conjunction) - - selection = reduce(operator.or_, disjunctions) + selection: cudf.Series = reduce( + operator.or_, + ( + reduce( + operator.and_, + ( + handlers[op](df[column], value) + for (column, op, value) in expr + ), + ) + for expr in filters + ), + ) if reset_index: return df[selection].reset_index(drop=True) return df[selection] diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index cb2f97eb613..3ffb71f0d2e 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -133,8 +133,6 @@ def _read_paths( raise err # Apply filters (if any are defined) - if filters is not None and not isinstance(filters, list): - raise TypeError(f"filters must be a list, got {type(filters)}.") df = _apply_post_filters(df, filters) if partitions and partition_keys is None: From 5c33b86b1d1a31647227a6837860815434a6072f Mon Sep 17 00:00:00 2001 From: rjzamora Date: Fri, 12 May 2023 14:22:15 -0700 Subject: [PATCH 07/10] move filter normalization --- python/cudf/cudf/io/parquet.py | 8 ++++---- python/dask_cudf/dask_cudf/io/parquet.py | 7 ++++++- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 32455405ebb..37ba862eadf 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -485,6 +485,9 @@ def read_parquet( path_or_data=filepath_or_buffer, storage_options=storage_options ) + # Normalize filters + filters = _normalize_filters(filters) + # Use pyarrow dataset to detect/process directory-partitioned # data and apply filters. Note that we can only support partitioned # data and filtering if the input is a single directory or list of @@ -505,8 +508,6 @@ def read_parquet( categorical_partitions=categorical_partitions, dataset_kwargs=dataset_kwargs, ) - elif filters is not None: - raise ValueError("cudf cannot apply filters to open file objects.") filepath_or_buffer = paths if paths else filepath_or_buffer filepaths_or_buffers = [] @@ -617,7 +618,7 @@ def _handle_is(column, value, *, negate): return ~column.isna() if negate else column.isna() -def _apply_post_filters(df: cudf.DataFrame, filters: list | None): +def _apply_post_filters(df: cudf.DataFrame, filters: List[List[tuple]] | None): # Apply DNF filters to an in-memory DataFrame # # Disjunctive normal form (DNF) means that the inner-most @@ -626,7 +627,6 @@ def _apply_post_filters(df: cudf.DataFrame, filters: list | None): # larger predicate. The outer-most list then combines all # of the combined filters with an OR disjunction. - filters = _normalize_filters(filters) if not filters: # No filters to apply return df diff --git a/python/dask_cudf/dask_cudf/io/parquet.py b/python/dask_cudf/dask_cudf/io/parquet.py index 3ffb71f0d2e..90d75b1249f 100644 --- a/python/dask_cudf/dask_cudf/io/parquet.py +++ b/python/dask_cudf/dask_cudf/io/parquet.py @@ -20,7 +20,11 @@ import cudf from cudf.core.column import as_column, build_categorical_column from cudf.io import write_to_dataset -from cudf.io.parquet import _apply_post_filters, _default_open_file_options +from cudf.io.parquet import ( + _apply_post_filters, + _default_open_file_options, + _normalize_filters, +) from cudf.utils.dtypes import cudf_dtype_from_pa_type from cudf.utils.ioutils import ( _ROW_GROUP_SIZE_BYTES_DEFAULT, @@ -133,6 +137,7 @@ def _read_paths( raise err # Apply filters (if any are defined) + filters = _normalize_filters(filters) df = _apply_post_filters(df, filters) if partitions and partition_keys is None: From 2f9ea8faff7c8343f7a19cf0923bbbad360d3687 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Mon, 15 May 2023 15:02:48 -0700 Subject: [PATCH 08/10] drop operator --- python/cudf/cudf/io/parquet.py | 90 ++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 38 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 37ba862eadf..f6a24114b7d 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -2,7 +2,6 @@ from __future__ import annotations import math -import operator import shutil import tempfile import warnings @@ -485,7 +484,7 @@ def read_parquet( path_or_data=filepath_or_buffer, storage_options=storage_options ) - # Normalize filters + # Normalize and validate filters filters = _normalize_filters(filters) # Use pyarrow dataset to detect/process directory-partitioned @@ -574,12 +573,12 @@ def read_parquet( def _normalize_filters(filters: list | None) -> List[List[tuple]] | None: # Utility to normalize and validate the `filters` # argument to `read_parquet` - if filters is not None: + if filters: msg = ( f"filters must be None, or non-empty List[Tuple] " f"or List[List[Tuple]]. Got {filters}" ) - if not filters or not isinstance(filters, list): + if not isinstance(filters, list): raise TypeError(msg) def _validate_predicate(item): @@ -596,48 +595,63 @@ def _validate_predicate(item): for predicate in conjunction: _validate_predicate(predicate) - return filters - - -def _handle_in(column, value, *, negate): - if not isinstance(value, (list, set, tuple)): - raise TypeError( - "Value of 'in' or 'not in' filter must be a list, set, or tuple." - ) - if negate: - return reduce(operator.and_, (operator.ne(column, v) for v in value)) + return filters else: - return reduce(operator.or_, (operator.eq(column, v) for v in value)) + return None -def _handle_is(column, value, *, negate): - if value not in {np.nan, None}: - raise TypeError( - "Value of 'is' or 'is not' filter must be np.nan or None." - ) - return ~column.isna() if negate else column.isna() - +def _apply_post_filters( + df: cudf.DataFrame, filters: List[List[tuple]] | None +) -> cudf.DataFrame: + """Apply DNF filters to an in-memory DataFrame -def _apply_post_filters(df: cudf.DataFrame, filters: List[List[tuple]] | None): - # Apply DNF filters to an in-memory DataFrame - # - # Disjunctive normal form (DNF) means that the inner-most - # tuple describes a single column predicate. These inner - # predicates are combined with an AND conjunction into a - # larger predicate. The outer-most list then combines all - # of the combined filters with an OR disjunction. + Disjunctive normal form (DNF) means that the inner-most + tuple describes a single column predicate. These inner + predicates are combined with an AND conjunction into a + larger predicate. The outer-most list then combines all + of the combined filters with an OR disjunction. + """ if not filters: # No filters to apply return df + def _handle_eq(column: cudf.Series, value, *, negate) -> cudf.Series: + return column != value if negate else column == value + + def _handle_gt(column: cudf.Series, value, *, negate) -> cudf.Series: + return column <= value if negate else column > value + + def _handle_lt(column: cudf.Series, value, *, negate) -> cudf.Series: + return column >= value if negate else column < value + + def _handle_in(column: cudf.Series, value, *, negate) -> cudf.Series: + if not isinstance(value, (list, set, tuple)): + raise TypeError( + "Value of 'in'/'not in' filter must be a list, set, or tuple." + ) + return ~column.isin(value) if negate else column.isin(value) + + def _handle_is(column: cudf.Series, value, *, negate) -> cudf.Series: + if value not in {np.nan, None}: + raise TypeError( + "Value of 'is'/'is not' filter must be np.nan or None." + ) + return ~column.isna() if negate else column.isna() + + def _handle_and(left: cudf.Series, right: cudf.Series) -> cudf.Series: + return left & right + + def _handle_or(left: cudf.Series, right: cudf.Series) -> cudf.Series: + return left | right + handlers: Dict[str, Callable] = { - "==": operator.eq, - "!=": operator.ne, - "<": operator.lt, - "<=": operator.le, - ">": operator.gt, - ">=": operator.ge, + "==": partial(_handle_eq, negate=False), + "!=": partial(_handle_eq, negate=True), + "<": partial(_handle_lt, negate=False), + ">=": partial(_handle_lt, negate=True), + ">": partial(_handle_gt, negate=False), + "<=": partial(_handle_gt, negate=True), "in": partial(_handle_in, negate=False), "not in": partial(_handle_in, negate=True), "is": partial(_handle_is, negate=False), @@ -656,10 +670,10 @@ def _apply_post_filters(df: cudf.DataFrame, filters: List[List[tuple]] | None): try: selection: cudf.Series = reduce( - operator.or_, + _handle_or, ( reduce( - operator.and_, + _handle_and, ( handlers[op](df[column], value) for (column, op, value) in expr From adc435803ff62a44302f842c887a277da8b4eac6 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 16 May 2023 07:24:38 -0700 Subject: [PATCH 09/10] add operator back --- python/cudf/cudf/io/parquet.py | 32 +++++++++----------------------- 1 file changed, 9 insertions(+), 23 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index f6a24114b7d..337803789c4 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -2,6 +2,7 @@ from __future__ import annotations import math +import operator import shutil import tempfile import warnings @@ -616,15 +617,6 @@ def _apply_post_filters( # No filters to apply return df - def _handle_eq(column: cudf.Series, value, *, negate) -> cudf.Series: - return column != value if negate else column == value - - def _handle_gt(column: cudf.Series, value, *, negate) -> cudf.Series: - return column <= value if negate else column > value - - def _handle_lt(column: cudf.Series, value, *, negate) -> cudf.Series: - return column >= value if negate else column < value - def _handle_in(column: cudf.Series, value, *, negate) -> cudf.Series: if not isinstance(value, (list, set, tuple)): raise TypeError( @@ -639,19 +631,13 @@ def _handle_is(column: cudf.Series, value, *, negate) -> cudf.Series: ) return ~column.isna() if negate else column.isna() - def _handle_and(left: cudf.Series, right: cudf.Series) -> cudf.Series: - return left & right - - def _handle_or(left: cudf.Series, right: cudf.Series) -> cudf.Series: - return left | right - handlers: Dict[str, Callable] = { - "==": partial(_handle_eq, negate=False), - "!=": partial(_handle_eq, negate=True), - "<": partial(_handle_lt, negate=False), - ">=": partial(_handle_lt, negate=True), - ">": partial(_handle_gt, negate=False), - "<=": partial(_handle_gt, negate=True), + "==": operator.eq, + "!=": operator.ne, + "<": operator.lt, + "<=": operator.le, + ">": operator.gt, + ">=": operator.ge, "in": partial(_handle_in, negate=False), "not in": partial(_handle_in, negate=True), "is": partial(_handle_is, negate=False), @@ -670,10 +656,10 @@ def _handle_or(left: cudf.Series, right: cudf.Series) -> cudf.Series: try: selection: cudf.Series = reduce( - _handle_or, + operator.or_, ( reduce( - _handle_and, + operator.and_, ( handlers[op](df[column], value) for (column, op, value) in expr From 2193bacb448a6b5d746d580477715c2f5b291324 Mon Sep 17 00:00:00 2001 From: rjzamora Date: Tue, 16 May 2023 09:12:18 -0700 Subject: [PATCH 10/10] move early return --- python/cudf/cudf/io/parquet.py | 43 +++++++++++++++++----------------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/python/cudf/cudf/io/parquet.py b/python/cudf/cudf/io/parquet.py index 337803789c4..6c91a5a8765 100644 --- a/python/cudf/cudf/io/parquet.py +++ b/python/cudf/cudf/io/parquet.py @@ -574,31 +574,30 @@ def read_parquet( def _normalize_filters(filters: list | None) -> List[List[tuple]] | None: # Utility to normalize and validate the `filters` # argument to `read_parquet` - if filters: - msg = ( - f"filters must be None, or non-empty List[Tuple] " - f"or List[List[Tuple]]. Got {filters}" - ) - if not isinstance(filters, list): - raise TypeError(msg) + if not filters: + return None - def _validate_predicate(item): - if not isinstance(item, tuple) or len(item) != 3: - raise TypeError( - f"Predicate must be Tuple[str, str, Any], " - f"got {predicate}." - ) + msg = ( + f"filters must be None, or non-empty List[Tuple] " + f"or List[List[Tuple]]. Got {filters}" + ) + if not isinstance(filters, list): + raise TypeError(msg) - filters = filters if isinstance(filters[0], list) else [filters] - for conjunction in filters: - if not conjunction or not isinstance(conjunction, list): - raise TypeError(msg) - for predicate in conjunction: - _validate_predicate(predicate) + def _validate_predicate(item): + if not isinstance(item, tuple) or len(item) != 3: + raise TypeError( + f"Predicate must be Tuple[str, str, Any], " f"got {predicate}." + ) - return filters - else: - return None + filters = filters if isinstance(filters[0], list) else [filters] + for conjunction in filters: + if not conjunction or not isinstance(conjunction, list): + raise TypeError(msg) + for predicate in conjunction: + _validate_predicate(predicate) + + return filters def _apply_post_filters(