Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Faster parquet streaming + filters with predicate pushdown #7309

Merged
merged 2 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 18 additions & 5 deletions src/datasets/packaged_modules/parquet/parquet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import itertools
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Union

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

import datasets
Expand All @@ -19,6 +20,7 @@ class ParquetConfig(datasets.BuilderConfig):
batch_size: Optional[int] = None
columns: Optional[List[str]] = None
features: Optional[datasets.Features] = None
filters: Optional[Union[ds.Expression, List[tuple], List[List[tuple]]]] = None

def __post_init__(self):
super().__post_init__()
Expand Down Expand Up @@ -77,14 +79,25 @@ def _generate_tables(self, files):
raise ValueError(
f"Tried to load parquet data with columns '{self.config.columns}' with mismatching features '{self.info.features}'"
)
filter_expr = (
pq.filters_to_expression(self.config.filters)
if isinstance(self.config.filters, list)
else self.config.filters
)
for file_idx, file in enumerate(itertools.chain.from_iterable(files)):
with open(file, "rb") as f:
parquet_file = pq.ParquetFile(f)
if parquet_file.metadata.num_row_groups > 0:
batch_size = self.config.batch_size or parquet_file.metadata.row_group(0).num_rows
parquet_fragment = ds.ParquetFileFormat().make_fragment(f)
if parquet_fragment.row_groups:
batch_size = self.config.batch_size or parquet_fragment.row_groups[0].num_rows
try:
for batch_idx, record_batch in enumerate(
parquet_file.iter_batches(batch_size=batch_size, columns=self.config.columns)
parquet_fragment.to_batches(
batch_size=batch_size,
columns=self.config.columns,
filter=filter_expr,
batch_readahead=0,
fragment_readahead=0,
)
):
pa_table = pa.Table.from_batches([record_batch])
# Uncomment for debugging (will print the Arrow table size and elements)
Expand Down
10 changes: 10 additions & 0 deletions tests/io/test_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,16 @@ def test_parquet_read_geoparquet(geoparquet_path, tmp_path):
assert dataset.features[feature].dtype == expected_dtype


def test_parquet_read_filters(parquet_path, tmp_path):
cache_dir = tmp_path / "cache"
filters = [("col_2", "==", 1)]
dataset = ParquetDatasetReader(path_or_paths=parquet_path, cache_dir=cache_dir, filters=filters).read()

assert isinstance(dataset, Dataset)
assert all(example["col_2"] == 1 for example in dataset)
assert dataset.num_rows == 1


def _check_parquet_datasetdict(dataset_dict, expected_features, splits=("train",)):
assert isinstance(dataset_dict, (DatasetDict, IterableDatasetDict))
for split in splits:
Expand Down
Loading