Skip to content

Commit

Permalink
Python: Add initial TableScan implementation (#6145)
Browse files Browse the repository at this point in the history
  • Loading branch information
Fokko authored Nov 10, 2022
1 parent 048f3f6 commit f54a10c
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 0 deletions.
103 changes: 103 additions & 0 deletions python/pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
)

from pydantic import Field

from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression
from pyiceberg.io import FileIO, load_file_io
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadata
Expand Down Expand Up @@ -58,6 +62,29 @@ def refresh(self):
"""Refresh the current table metadata"""
raise NotImplementedError("To be implemented")

def name(self) -> Identifier:
    """Return the identifier of this table.

    Returns:
        The value stored in ``self.identifier``, unchanged.
    """
    table_identifier = self.identifier
    return table_identifier

def scan(
    self,
    row_filter: Optional[BooleanExpression] = None,
    partition_filter: Optional[BooleanExpression] = None,
    selected_fields: Tuple[str, ...] = ("*",),
    case_sensitive: bool = True,
    snapshot_id: Optional[int] = None,
    options: Properties = EMPTY_DICT,
) -> TableScan:
    """Create a ``TableScan`` over this table.

    Args:
        row_filter: Expression restricting the rows to read; a falsy value
            (including the default ``None``) is replaced by ``AlwaysTrue()``.
        partition_filter: Expression restricting the partitions to read; a
            falsy value is likewise replaced by ``AlwaysTrue()``.
        selected_fields: Names of the fields to read. The default ``("*",)``
            is passed through unchanged.
        case_sensitive: Forwarded to the scan unchanged.
        snapshot_id: Forwarded to the scan unchanged; ``None`` leaves the
            scan without an explicit snapshot.
        options: Forwarded to the scan unchanged.

    Returns:
        A new ``TableScan`` bound to this table.
    """
    return TableScan(
        table=self,
        row_filter=row_filter or AlwaysTrue(),
        partition_filter=partition_filter or AlwaysTrue(),
        selected_fields=selected_fields,
        case_sensitive=case_sensitive,
        snapshot_id=snapshot_id,
        options=options,
    )

def schema(self) -> Schema:
    """Return the schema for this table.

    Walks ``self.metadata.schemas`` and returns the first entry whose
    ``schema_id`` equals ``self.metadata.current_schema_id``.

    Raises:
        StopIteration: If no schema in the metadata carries the current
            schema id.
    """
    metadata = self.metadata
    matching_schemas = (
        candidate for candidate in metadata.schemas if candidate.schema_id == metadata.current_schema_id
    )
    # No default is supplied on purpose: an exhausted generator surfaces as StopIteration.
    return next(matching_schemas)
Expand Down Expand Up @@ -123,3 +150,79 @@ def __eq__(self, other: Any) -> bool:
if isinstance(other, Table)
else False
)


class TableScan:
    """Description of a read over a table.

    A scan records which table to read, the row and partition filters, the
    selected field names, case sensitivity, an optional snapshot id and a set
    of options. The refinement methods (``select``, ``filter_rows``,
    ``filter_partitions``, ``with_case_sensitive``, ``use_ref``) do not mutate
    the scan; they return a new one built through ``update``.

    Attributes:
        table: The table being scanned.
        row_filter: Expression restricting the rows to read.
        partition_filter: Expression restricting the partitions to read.
        selected_fields: Names of the selected fields; ``"*"`` selects all.
        case_sensitive: Passed to ``Schema.select`` when projecting.
        snapshot_id: Explicit snapshot to read, or ``None`` for the table's
            current snapshot.
        options: Stored as given; not read by any method of this class.
    """

    table: Table
    row_filter: BooleanExpression
    partition_filter: BooleanExpression
    selected_fields: Tuple[str, ...]
    case_sensitive: bool
    snapshot_id: Optional[int]
    options: Properties

    def __init__(
        self,
        table: Table,
        row_filter: Optional[BooleanExpression] = None,
        partition_filter: Optional[BooleanExpression] = None,
        selected_fields: Tuple[str, ...] = ("*",),
        case_sensitive: bool = True,
        snapshot_id: Optional[int] = None,
        options: Properties = EMPTY_DICT,
    ):
        """Store the scan parameters.

        A falsy ``row_filter`` or ``partition_filter`` (including ``None``)
        is replaced by ``AlwaysTrue()``; every other argument is stored as
        given.
        """
        self.table = table
        self.row_filter = row_filter or AlwaysTrue()
        self.partition_filter = partition_filter or AlwaysTrue()
        self.selected_fields = selected_fields
        self.case_sensitive = case_sensitive
        self.snapshot_id = snapshot_id
        self.options = options

    def snapshot(self) -> Optional[Snapshot]:
        """Return the snapshot this scan reads.

        Returns:
            The result of ``table.snapshot_by_id`` when a snapshot id was
            set, otherwise the result of ``table.current_snapshot()``.
        """
        # Explicit None check so that an id of 0 is still looked up by id.
        if self.snapshot_id is not None:
            return self.table.snapshot_by_id(self.snapshot_id)
        return self.table.current_snapshot()

    def projection(self) -> Schema:
        """Return the schema this scan projects.

        Starts from ``table.schema()``. When the scan's snapshot exists and
        records a schema id, the entry of ``table.schemas()`` under that id
        is used instead. If ``"*"`` is among the selected fields the schema
        is returned as is; otherwise it is narrowed with ``Schema.select``
        using the selected field names and the scan's case sensitivity.
        """
        snapshot_schema = self.table.schema()
        snapshot = self.snapshot()
        # Compare against None instead of relying on truthiness: a schema id
        # of 0 is falsy, yet it may still identify the snapshot's schema.
        if snapshot is not None and snapshot.schema_id is not None:
            snapshot_schema = self.table.schemas()[snapshot.schema_id]

        if "*" in self.selected_fields:
            return snapshot_schema

        return snapshot_schema.select(*self.selected_fields, case_sensitive=self.case_sensitive)

    def plan_files(self):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not yet implemented")

    def to_arrow(self):
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError("Not yet implemented")

    def update(self, **overrides: Any) -> TableScan:
        """Creates a copy of this table scan with updated fields.

        Args:
            **overrides: Constructor arguments that replace the values held
                by this scan.

        Returns:
            A new scan of the same class as this one.
        """
        # The instance attributes mirror the constructor parameters, so the
        # instance dict can be replayed as keyword arguments. Using
        # type(self) keeps subclasses intact across refinement calls.
        return type(self)(**{**self.__dict__, **overrides})

    def use_ref(self, name: str) -> TableScan:
        """Return a copy of this scan pinned to the snapshot of a named ref.

        Args:
            name: Name passed to ``table.snapshot_by_name``.

        Raises:
            ValueError: If a snapshot id is already set on this scan, or if
                no snapshot is found for ``name``.
        """
        if self.snapshot_id is not None:
            raise ValueError(f"Cannot override ref, already set snapshot id={self.snapshot_id}")
        if snapshot := self.table.snapshot_by_name(name):
            return self.update(snapshot_id=snapshot.snapshot_id)

        raise ValueError(f"Cannot scan unknown ref={name}")

    def select(self, *field_names: str) -> TableScan:
        """Return a copy of this scan with a narrowed field selection.

        If the current selection contains ``"*"``, it is replaced by
        ``field_names``. Otherwise the new selection holds the names present
        in both the current selection and ``field_names``, in the order of
        the current selection and without duplicates.
        """
        if "*" in self.selected_fields:
            return self.update(selected_fields=field_names)

        requested = set(field_names)
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is deterministic (a plain set intersection is not ordered).
        narrowed = tuple(dict.fromkeys(field for field in self.selected_fields if field in requested))
        return self.update(selected_fields=narrowed)

    def filter_rows(self, new_row_filter: BooleanExpression) -> TableScan:
        """Return a copy whose row filter is ``And(current, new_row_filter)``."""
        return self.update(row_filter=And(self.row_filter, new_row_filter))

    def filter_partitions(self, new_partition_filter: BooleanExpression) -> TableScan:
        """Return a copy whose partition filter is ``And(current, new_partition_filter)``."""
        return self.update(partition_filter=And(self.partition_filter, new_partition_filter))

    def with_case_sensitive(self, case_sensitive: bool = True) -> TableScan:
        """Return a copy of this scan with ``case_sensitive`` replaced."""
        return self.update(case_sensitive=case_sensitive)
83 changes: 83 additions & 0 deletions python/tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

import pytest

from pyiceberg.expressions import (
AlwaysTrue,
And,
EqualTo,
In,
)
from pyiceberg.schema import Schema
from pyiceberg.table import PartitionSpec, Table
from pyiceberg.table.metadata import TableMetadataV2
Expand Down Expand Up @@ -179,3 +185,80 @@ def test_history(table):
SnapshotLogEntry(snapshot_id="3051729675574597004", timestamp_ms=1515100955770),
SnapshotLogEntry(snapshot_id="3055729675574597004", timestamp_ms=1555100955770),
]


def test_table_scan_select(table: Table):
    """A fresh scan selects "*"; select() replaces it and later calls narrow it."""
    base_scan = table.scan()
    assert base_scan.selected_fields == ("*",)

    two_columns = base_scan.select("a", "b")
    assert two_columns.selected_fields == ("a", "b")

    narrowed = base_scan.select("a", "c").select("a")
    assert narrowed.selected_fields == ("a",)


def test_table_scan_row_filter(table: Table):
    """Row filters start as AlwaysTrue and are combined with And as they are added."""
    base_scan = table.scan()
    assert base_scan.row_filter == AlwaysTrue()

    x_equals_ten = base_scan.filter_rows(EqualTo("x", 10))  # type: ignore
    assert x_equals_ten.row_filter == EqualTo("x", 10)  # type: ignore

    both_filters = base_scan.filter_rows(EqualTo("x", 10)).filter_rows(In("y", (10, 11)))  # type: ignore
    expected = And(EqualTo("x", 10), In("y", (10, 11)))  # type: ignore
    assert both_filters.row_filter == expected


def test_table_scan_partition_filter(table: Table):
    """Partition filters start as AlwaysTrue and are combined with And as they are added."""
    scan = table.scan()
    # Check the partition filter here; the row filter has its own test.
    assert scan.partition_filter == AlwaysTrue()
    assert scan.filter_partitions(EqualTo("x", 10)).partition_filter == EqualTo("x", 10)  # type: ignore
    assert scan.filter_partitions(EqualTo("x", 10)).filter_partitions(In("y", (10, 11))).partition_filter == And(  # type: ignore
        EqualTo("x", 10), In("y", (10, 11))  # type: ignore
    )


def test_table_scan_ref(table: Table):
    """use_ref("test") yields a scan carrying the expected snapshot id."""
    pinned_scan = table.scan().use_ref("test")
    assert pinned_scan.snapshot_id == 3051729675574597004


def test_table_scan_ref_does_not_exists(table: Table):
    """use_ref with an unknown name raises ValueError naming the ref."""
    base_scan = table.scan()

    with pytest.raises(ValueError) as raised:
        base_scan.use_ref("boom")

    message = str(raised.value)
    assert "Cannot scan unknown ref=boom" in message


def test_table_scan_projection_full_schema(table: Table):
    """Selecting x, y and z projects a schema holding all three fields."""
    expected_schema = Schema(
        NestedField(field_id=1, name="x", field_type=LongType(), required=True),
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        NestedField(field_id=3, name="z", field_type=LongType(), required=True),
        schema_id=1,
        identifier_field_ids=[1, 2],
    )

    projected = table.scan().select("x", "y", "z").projection()

    assert projected == expected_schema


def test_table_scan_projection_single_column(table: Table):
    """Selecting only y projects a schema holding just that field."""
    expected_schema = Schema(
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        schema_id=1,
        identifier_field_ids=[2],
    )

    projected = table.scan().select("y").projection()

    assert projected == expected_schema


def test_table_scan_projection_single_column_case_sensitive(table: Table):
    """With case sensitivity switched off, selecting "Y" projects the field named "y"."""
    expected_schema = Schema(
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        schema_id=1,
        identifier_field_ids=[2],
    )

    case_insensitive_scan = table.scan().with_case_sensitive(False)
    projected = case_insensitive_scan.select("Y").projection()

    assert projected == expected_schema


def test_table_scan_projection_unknown_column(table: Table):
    """Projecting a column absent from the schema raises ValueError naming it."""
    base_scan = table.scan()

    with pytest.raises(ValueError) as raised:
        base_scan.select("a").projection()

    message = str(raised.value)
    assert "Could not find column: 'a'" in message

0 comments on commit f54a10c

Please sign in to comment.