PyArrow: Avoid buffer-overflow by avoiding a sort #1555

Merged 13 commits on Jan 23, 2025
124 changes: 46 additions & 78 deletions pyiceberg/io/pyarrow.py
@@ -27,8 +27,10 @@

import concurrent.futures
import fnmatch
import functools
import itertools
import logging
import operator
import os
import re
import uuid
@@ -2542,38 +2544,8 @@ class _TablePartition:
arrow_table_partition: pa.Table


def _get_table_partitions(
arrow_table: pa.Table,
partition_spec: PartitionSpec,
schema: Schema,
slice_instructions: list[dict[str, Any]],
) -> list[_TablePartition]:
sorted_slice_instructions = sorted(slice_instructions, key=lambda x: x["offset"])

partition_fields = partition_spec.fields

offsets = [inst["offset"] for inst in sorted_slice_instructions]
projected_and_filtered = {
partition_field.source_id: arrow_table[schema.find_field(name_or_id=partition_field.source_id).name]
.take(offsets)
.to_pylist()
for partition_field in partition_fields
}

table_partitions = []
for idx, inst in enumerate(sorted_slice_instructions):
partition_slice = arrow_table.slice(**inst)
fieldvalues = [
PartitionFieldValue(partition_field, projected_and_filtered[partition_field.source_id][idx])
for partition_field in partition_fields
]
partition_key = PartitionKey(raw_partition_field_values=fieldvalues, partition_spec=partition_spec, schema=schema)
table_partitions.append(_TablePartition(partition_key=partition_key, arrow_table_partition=partition_slice))
return table_partitions


def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.Table) -> List[_TablePartition]:
"""Based on the iceberg table partition spec, slice the arrow table into partitions with their keys.
"""Based on the iceberg table partition spec, filter the arrow table into partitions with their keys.

Example:
Input:
@@ -2582,54 +2554,50 @@ def _determine_partitions(spec: PartitionSpec, schema: Schema, arrow_table: pa.T
'n_legs': [2, 2, 2, 4, 4, 4, 4, 5, 100],
'animal': ["Flamingo", "Parrot", "Parrot", "Dog", "Horse", "Horse", "Horse","Brittle stars", "Centipede"]}.
The algorithm:
Firstly we group the rows into partitions by sorting with sort order [('n_legs', 'descending'), ('year', 'descending')]
and null_placement of "at_end".
This gives the same table as raw input.
Then we sort_indices using reverse order of [('n_legs', 'descending'), ('year', 'descending')]
and null_placement : "at_start".
This gives:
[8, 7, 4, 5, 6, 3, 1, 2, 0]
Based on this we get partition groups of indices:
[{'offset': 8, 'length': 1}, {'offset': 7, 'length': 1}, {'offset': 4, 'length': 3}, {'offset': 3, 'length': 1}, {'offset': 1, 'length': 2}, {'offset': 0, 'length': 1}]
We then retrieve the partition keys by offsets.
And slice the arrow table by offsets and lengths of each partition.
- We determine the set of unique partition keys
- Then we produce a set of partitions by filtering on each of the combinations
- We combine the chunks to create a copy to avoid GIL congestion on the original table
"""
partition_columns: List[Tuple[PartitionField, NestedField]] = [
(partition_field, schema.find_field(partition_field.source_id)) for partition_field in spec.fields
]
partition_values_table = pa.table(
{
str(partition.field_id): partition.transform.pyarrow_transform(field.field_type)(arrow_table[field.name])
for partition, field in partition_columns
}
)
# Assign unique names to columns where the partition transform has been applied
# to avoid conflicts
partition_fields = [f"_partition_{field.name}" for field in spec.fields]

for partition, name in zip(spec.fields, partition_fields):
source_field = schema.find_field(partition.source_id)
arrow_table = arrow_table.append_column(
name, partition.transform.pyarrow_transform(source_field.field_type)(arrow_table[source_field.name])
)

unique_partition_fields = arrow_table.select(partition_fields).group_by(partition_fields).aggregate([])

table_partitions = []
# TODO: As a next step, we could also play around with yielding instead of materializing the full list
for unique_partition in unique_partition_fields.to_pylist():
partition_key = PartitionKey(
raw_partition_field_values=[
PartitionFieldValue(field=field, value=unique_partition[name])
for field, name in zip(spec.fields, partition_fields)
],
partition_spec=spec,
schema=schema,
)
filtered_table = arrow_table.filter(
functools.reduce(
operator.and_,
[
pc.field(partition_field_name) == unique_partition[partition_field_name]
if unique_partition[partition_field_name] is not None
else pc.field(partition_field_name).is_null()
for field, partition_field_name in zip(spec.fields, partition_fields)
],
)
)
filtered_table = filtered_table.drop_columns(partition_fields)

# Sort by partitions
sort_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "ascending") for col in partition_values_table.column_names],
null_placement="at_end",
).to_pylist()
arrow_table = arrow_table.take(sort_indices)

# Get slice_instructions to group by partitions
partition_values_table = partition_values_table.take(sort_indices)
reversed_indices = pa.compute.sort_indices(
partition_values_table,
sort_keys=[(col, "descending") for col in partition_values_table.column_names],
null_placement="at_start",
).to_pylist()
slice_instructions: List[Dict[str, Any]] = []
last = len(reversed_indices)
reversed_indices_size = len(reversed_indices)
ptr = 0
while ptr < reversed_indices_size:
group_size = last - reversed_indices[ptr]
offset = reversed_indices[ptr]
slice_instructions.append({"offset": offset, "length": group_size})
last = reversed_indices[ptr]
ptr = ptr + group_size

table_partitions: List[_TablePartition] = _get_table_partitions(arrow_table, spec, schema, slice_instructions)
# The combine_chunks seems to be counter-intuitive to do, but it actually returns
# fresh buffers that don't interfere with each other when it is written out to file
table_partitions.append(
_TablePartition(partition_key=partition_key, arrow_table_partition=filtered_table.combine_chunks())
)

return table_partitions
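
For anyone following the new algorithm outside the diff, here is a minimal standalone sketch of the same idea: collect the unique partition key values, filter the table once per key, and call combine_chunks() so every partition gets fresh buffers. The table and column names are made up for illustration, and it assumes a PyArrow version whose Table.filter accepts compute expressions:

```python
import pyarrow as pa
import pyarrow.compute as pc

table = pa.table({
    "n_legs": [2, 2, 4, 4, 5, None],
    "animal": ["Flamingo", "Parrot", "Dog", "Horse", "Brittle stars", "Centipede"],
})

# Unique values of the (illustrative) partition column "n_legs".
unique_keys = table.select(["n_legs"]).group_by(["n_legs"]).aggregate([]).to_pylist()

partitions = []
for key in unique_keys:
    value = key["n_legs"]
    # Nulls need an explicit is_null() predicate; an equality test never matches null.
    predicate = pc.field("n_legs") == value if value is not None else pc.field("n_legs").is_null()
    part = table.filter(predicate)
    # combine_chunks() copies the filtered rows into fresh buffers, so the
    # partitions do not share memory with the original table when written out.
    partitions.append(part.combine_chunks())
```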
6 changes: 4 additions & 2 deletions pyiceberg/partitioning.py
@@ -413,8 +413,10 @@ def partition_record_value(partition_field: PartitionField, value: Any, schema:
the final partition record value.
"""
iceberg_type = schema.find_field(name_or_id=partition_field.source_id).field_type
iceberg_typed_value = _to_partition_representation(iceberg_type, value)
transformed_value = partition_field.transform.transform(iceberg_type)(iceberg_typed_value)
if not isinstance(value, int):
# When adding files, it can be that we still need to convert from logical types to physical types
value = _to_partition_representation(iceberg_type, value)
transformed_value = partition_field.transform.transform(iceberg_type)(value)
Contributor Author:
This is causing bugs; I'm going to revisit this to fix it properly.

Contributor Author:
Yes, so I got to the bottom of it. It has to do with the return types of the transforms, e.g. when we apply the bucket transform, the result is always an int, which is great. The problem is with the identity transform, where the destination type is equal to the source type. So when a date comes in, it also comes out.

I think in the end it is better to remove the _to_partition_representation and see if we can consolidate this somewhere, but that's a different PR.

Contributor:
> So when a date comes in, it also comes out.

Is it due to not having support for datetime literals? #1542

Contributor:
Also, if it's just for adding files, perhaps we can do something special just for that path.

Contributor Author:
> Also, if it's just for adding files, perhaps we can do something special just for that path.

Yes, that's exactly what I went for. I think we can simplify the logic in subsequent PRs :)

return transformed_value
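
To make the return-type point from the thread above concrete, here is a small hedged sketch (class names taken from pyiceberg's transforms and types modules; exact behaviour may differ between versions): a bucket transform always yields an int partition value, while an identity transform hands the source value back unchanged, so a date stays a date.

```python
from datetime import date

from pyiceberg.transforms import BucketTransform, IdentityTransform
from pyiceberg.types import DateType, IntegerType

# Bucket transforms hash to a bucket number, so the result is always an int.
bucket = BucketTransform(num_buckets=8).transform(IntegerType())
print(bucket(34))  # some int in [0, 8)

# Identity transforms return the input as-is, so a date comes back as a date
# rather than the physical days-since-epoch int a partition record expects.
identity = IdentityTransform().transform(DateType())
print(identity(date(2022, 1, 1)))  # datetime.date(2022, 1, 1)
```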


1 change: 1 addition & 0 deletions pyproject.toml
@@ -1220,6 +1220,7 @@ markers = [
"adls: marks a test as requiring access to adls compliant storage (use with --adls.account-name, --adls.account-key, and --adls.endpoint args)",
"integration: marks integration tests against Apache Spark",
"gcs: marks a test as requiring access to gcs compliant storage (use with --gs.token, --gs.project, and --gs.endpoint)",
"benchmark: collection of tests to validate read/write performance before and after a change"
]

# Turns a warning into an error
72 changes: 72 additions & 0 deletions tests/benchmark/test_benchmark.py
@@ -0,0 +1,72 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import statistics
import timeit
import urllib.request

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from pyiceberg.transforms import DayTransform


@pytest.fixture(scope="session")
def taxi_dataset(tmp_path_factory: pytest.TempPathFactory) -> pa.Table:
"""Reads the Taxi dataset to disk"""
taxi_dataset = "https://d37ci6vzurychx.cloudfront.net/trip-data/yellow_tripdata_2022-01.parquet"
taxi_dataset_dest = tmp_path_factory.mktemp("taxi_dataset") / "yellow_tripdata_2022-01.parquet"
urllib.request.urlretrieve(taxi_dataset, taxi_dataset_dest)

return pq.read_table(taxi_dataset_dest)


@pytest.mark.benchmark
def test_partitioned_write(tmp_path_factory: pytest.TempPathFactory, taxi_dataset: pa.Table) -> None:
"""Tests writing to a partitioned table with something that would be close a production-like situation"""
from pyiceberg.catalog.sql import SqlCatalog

warehouse_path = str(tmp_path_factory.mktemp("warehouse"))
catalog = SqlCatalog(
"default",
uri=f"sqlite:///{warehouse_path}/pyiceberg_catalog.db",
warehouse=f"file://{warehouse_path}",
)

catalog.create_namespace("default")

tbl = catalog.create_table("default.taxi_partitioned", schema=taxi_dataset.schema)

with tbl.update_spec() as spec:
spec.add_field("tpep_pickup_datetime", DayTransform())

# Profiling can sometimes be handy as well
# with cProfile.Profile() as pr:
# tbl.append(taxi_dataset)
#
# pr.print_stats(sort=True)

runs = []
for run in range(5):
start_time = timeit.default_timer()
tbl.append(taxi_dataset)
elapsed = timeit.default_timer() - start_time

print(f"Run {run} took: {elapsed}")
runs.append(elapsed)

print(f"Average runtime of {round(statistics.mean(runs), 2)} seconds")
20 changes: 20 additions & 0 deletions tests/integration/test_writes/test_partitioned_writes.py
@@ -1126,3 +1126,23 @@ def test_append_multiple_partitions(
"""
)
assert files_df.count() == 6


@pytest.mark.integration
def test_pyarrow_overflow(session_catalog: Catalog) -> None:
"""Test what happens when the offset is beyond 32 bits"""
identifier = "default.arrow_table_overflow"
try:
session_catalog.drop_table(identifier=identifier)
except NoSuchTableError:
pass

arr = ["fixed_string"] * 30_000
strings = pa.chunked_array([arr] * 10_000)
# 300 million rows of a 12-byte string is ~3.6 GB of character data
arrow_table = pa.table({"a": strings})

table = session_catalog.create_table(identifier, arrow_table.schema)
with table.update_spec() as update_spec:
update_spec.add_field("b", IdentityTransform(), "pb")
table.append(arrow_table)
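
Rough arithmetic (not part of the PR) on why this table is a good overflow probe: pa.string() arrays index their character data with signed 32-bit offsets, so a single array tops out at 2**31 - 1 bytes, and per the PR title it was the old sort/take path that ended up gathering this much data into contiguous buffers.

```python
# Back-of-the-envelope check using the constants from the test above.
rows = 30_000 * 10_000              # 300 million rows
total_bytes = rows * len("fixed_string")  # 12 bytes per value
print(total_bytes, total_bytes > 2**31 - 1)  # 3600000000 True -> beyond int32 string offsets
```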