Skip to content

Commit

Permalink
Use Pydantic's model_copy for model modification (#182)
Browse files Browse the repository at this point in the history
* Implement table metadata updater first draft

* fix updater error and add tests

* implement apply_metadata_update which is simpler

* remove old implementation

* re-organize method place

* fix nit

* fix test

* add another test

* clear TODO

* add a combined test

* Fix merge conflict

* remove table requirement validation for PR simplification

* make context private and solve elif issue

* remove private field access

* push snapshot ref validation to its builder using pydantic

* fix comment

* remove unnecessary code for AddSchemaUpdate update

* replace if with elif

* switch to model_copy()

* enhance the set current schema update implementation and some other changes

* make apply_table_update private

* fix lint after merge

* add validation

* add test for isolation of illegal updates

* fix nit

* remove unnecessary flag

* change to model_copy(deep=True)
  • Loading branch information
HonahX authored Dec 6, 2023
1 parent 34b18e4 commit a368bd9
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 30 deletions.
57 changes: 28 additions & 29 deletions pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,12 +417,13 @@ def _(update: AddSchemaUpdate, base_metadata: TableMetadata, context: _TableMeta
if update.last_column_id < base_metadata.last_column_id:
raise ValueError(f"Invalid last column id {update.last_column_id}, must be >= {base_metadata.last_column_id}")

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["last-column-id"] = update.last_column_id
updated_metadata_data["schemas"].append(update.schema_.model_dump())

context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)
return base_metadata.model_copy(
update={
"last_column_id": update.last_column_id,
"schemas": base_metadata.schemas + [update.schema_],
}
)


@_apply_table_update.register(SetCurrentSchemaUpdate)
Expand All @@ -441,11 +442,8 @@ def _(update: SetCurrentSchemaUpdate, base_metadata: TableMetadata, context: _Ta
if schema is None:
raise ValueError(f"Schema with id {new_schema_id} does not exist")

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["current-schema-id"] = new_schema_id

context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)
return base_metadata.model_copy(update={"current_schema_id": new_schema_id})


@_apply_table_update.register(AddSnapshotUpdate)
Expand All @@ -469,12 +467,14 @@ def _(update: AddSnapshotUpdate, base_metadata: TableMetadata, context: _TableMe
f"older than last sequence number {base_metadata.last_sequence_number}"
)

updated_metadata_data = copy(base_metadata.model_dump())
updated_metadata_data["last-updated-ms"] = update.snapshot.timestamp_ms
updated_metadata_data["last-sequence-number"] = update.snapshot.sequence_number
updated_metadata_data["snapshots"].append(update.snapshot.model_dump())
context.add_update(update)
return TableMetadataUtil.parse_obj(updated_metadata_data)
return base_metadata.model_copy(
update={
"last_updated_ms": update.snapshot.timestamp_ms,
"last_sequence_number": update.snapshot.sequence_number,
"snapshots": base_metadata.snapshots + [update.snapshot],
}
)


@_apply_table_update.register(SetSnapshotRefUpdate)
Expand All @@ -493,28 +493,27 @@ def _(update: SetSnapshotRefUpdate, base_metadata: TableMetadata, context: _Tabl

snapshot = base_metadata.snapshot_by_id(snapshot_ref.snapshot_id)
if snapshot is None:
raise ValueError(f"Cannot set {snapshot_ref.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")
raise ValueError(f"Cannot set {update.ref_name} to unknown snapshot {snapshot_ref.snapshot_id}")

update_metadata_data = copy(base_metadata.model_dump())
update_last_updated_ms = True
metadata_updates: Dict[str, Any] = {}
if context.is_added_snapshot(snapshot_ref.snapshot_id):
update_metadata_data["last-updated-ms"] = snapshot.timestamp_ms
update_last_updated_ms = False
metadata_updates["last_updated_ms"] = snapshot.timestamp_ms

if update.ref_name == MAIN_BRANCH:
update_metadata_data["current-snapshot-id"] = snapshot_ref.snapshot_id
if update_last_updated_ms:
update_metadata_data["last-updated-ms"] = datetime_to_millis(datetime.datetime.now().astimezone())
update_metadata_data["snapshot-log"].append(
metadata_updates["current_snapshot_id"] = snapshot_ref.snapshot_id
if "last_updated_ms" not in metadata_updates:
metadata_updates["last_updated_ms"] = datetime_to_millis(datetime.datetime.now().astimezone())

metadata_updates["snapshot_log"] = base_metadata.snapshot_log + [
SnapshotLogEntry(
snapshot_id=snapshot_ref.snapshot_id,
timestamp_ms=update_metadata_data["last-updated-ms"],
).model_dump()
)
timestamp_ms=metadata_updates["last_updated_ms"],
)
]

update_metadata_data["refs"][update.ref_name] = snapshot_ref.model_dump()
metadata_updates["refs"] = {**base_metadata.refs, update.ref_name: snapshot_ref}
context.add_update(update)
return TableMetadataUtil.parse_obj(update_metadata_data)
return base_metadata.model_copy(update=metadata_updates)


def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpdate, ...]) -> TableMetadata:
Expand All @@ -533,7 +532,7 @@ def update_table_metadata(base_metadata: TableMetadata, updates: Tuple[TableUpda
for update in updates:
new_metadata = _apply_table_update(update, new_metadata, context)

return new_metadata
return new_metadata.model_copy(deep=True)


class TableRequirement(IcebergBaseModel):
Expand Down
51 changes: 50 additions & 1 deletion tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint:disable=redefined-outer-name
from copy import copy
from typing import Dict

import pytest
Expand Down Expand Up @@ -50,7 +51,7 @@
_TableMetadataUpdateContext,
update_table_metadata,
)
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER
from pyiceberg.table.metadata import INITIAL_SEQUENCE_NUMBER, TableMetadataUtil, TableMetadataV2
from pyiceberg.table.snapshots import (
Operation,
Snapshot,
Expand Down Expand Up @@ -640,9 +641,12 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None:
)

new_metadata = update_table_metadata(base_metadata, test_updates)
# rebuild the metadata to trigger validation
new_metadata = TableMetadataUtil.parse_obj(copy(new_metadata.model_dump()))

# UpgradeFormatVersionUpdate
assert new_metadata.format_version == 2
assert isinstance(new_metadata, TableMetadataV2)

# UpdateSchema
assert len(new_metadata.schemas) == 2
Expand All @@ -669,6 +673,51 @@ def test_update_metadata_with_multiple_updates(table_v1: Table) -> None:
)


def test_metadata_isolation_from_illegal_updates(table_v1: Table) -> None:
base_metadata = table_v1.metadata
base_metadata_backup = base_metadata.model_copy(deep=True)

# Apply legal updates on the table metadata
transaction = table_v1.transaction()
schema_update_1 = transaction.update_schema()
schema_update_1.add_column(path="b", field_type=IntegerType())
schema_update_1.commit()
test_updates = transaction._updates # pylint: disable=W0212
new_snapshot = Snapshot(
snapshot_id=25,
parent_snapshot_id=19,
sequence_number=200,
timestamp_ms=1602638573590,
manifest_list="s3:/a/b/c.avro",
summary=Summary(Operation.APPEND),
schema_id=3,
)
test_updates += (
AddSnapshotUpdate(snapshot=new_snapshot),
SetSnapshotRefUpdate(
ref_name="main",
type="branch",
snapshot_id=25,
max_ref_age_ms=123123123,
max_snapshot_age_ms=12312312312,
min_snapshots_to_keep=1,
),
)
new_metadata = update_table_metadata(base_metadata, test_updates)

# Check that the original metadata is not modified
assert base_metadata == base_metadata_backup

# Perform illegal update on the new metadata:
# TableMetadata should be immutable, but the pydantic's frozen config cannot prevent
# operations such as list append.
new_metadata.partition_specs.append(PartitionSpec(spec_id=0))
assert len(new_metadata.partition_specs) == 2

# The original metadata should not be affected by the illegal update on the new metadata
assert len(base_metadata.partition_specs) == 1


def test_generate_snapshot_id(table_v2: Table) -> None:
assert isinstance(_generate_snapshot_id(), int)
assert isinstance(table_v2.new_snapshot_id(), int)

0 comments on commit a368bd9

Please sign in to comment.