Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python: Add initial TableScan implementation #6145

Merged
merged 1 commit into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 103 additions & 0 deletions python/pyiceberg/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import (
Any,
Dict,
List,
Optional,
Tuple,
)

from pydantic import Field

from pyiceberg.expressions import AlwaysTrue, And, BooleanExpression
from pyiceberg.io import FileIO, load_file_io
from pyiceberg.schema import Schema
from pyiceberg.table.metadata import TableMetadata
Expand Down Expand Up @@ -58,6 +62,29 @@ def refresh(self):
"""Refresh the current table metadata"""
raise NotImplementedError("To be implemented")

def name(self) -> Identifier:
    """Return the identifier stored on this table."""
    table_identifier = self.identifier
    return table_identifier

def scan(
    self,
    row_filter: Optional[BooleanExpression] = None,
    partition_filter: Optional[BooleanExpression] = None,
    selected_fields: Tuple[str, ...] = ("*",),
    case_sensitive: bool = True,
    snapshot_id: Optional[int] = None,
    options: Properties = EMPTY_DICT,
) -> TableScan:
    """Create a TableScan over this table.

    Args:
        row_filter: Expression to use as the scan's row filter; AlwaysTrue() is used when omitted.
        partition_filter: Expression to use as the scan's partition filter; AlwaysTrue() is used when omitted.
        selected_fields: Names of the columns to select; the default ("*",) is the wildcard selection.
        case_sensitive: Case-sensitivity flag handed to the scan.
        snapshot_id: ID of the snapshot to scan, handed to the scan unchanged.
        options: Additional properties handed to the scan unchanged.

    Returns:
        A new TableScan bound to this table.
    """
    return TableScan(
        table=self,
        row_filter=row_filter or AlwaysTrue(),
        partition_filter=partition_filter or AlwaysTrue(),
        selected_fields=selected_fields,
        case_sensitive=case_sensitive,
        snapshot_id=snapshot_id,
        options=options,
    )

def schema(self) -> Schema:
    """Return the schema whose ID matches the metadata's current schema ID."""
    current_id = self.metadata.current_schema_id
    # next() on the generator keeps the original behaviour of raising StopIteration when nothing matches
    matching = (candidate for candidate in self.metadata.schemas if candidate.schema_id == current_id)
    return next(matching)
Expand Down Expand Up @@ -123,3 +150,79 @@ def __eq__(self, other: Any) -> bool:
if isinstance(other, Table)
else False
)


class TableScan:
table: Table
row_filter: BooleanExpression
partition_filter: BooleanExpression
selected_fields: Tuple[str]
case_sensitive: bool
snapshot_id: Optional[int]
options: Properties

def __init__(
    self,
    table: Table,
    row_filter: Optional[BooleanExpression] = None,
    partition_filter: Optional[BooleanExpression] = None,
    selected_fields: Tuple[str, ...] = ("*",),
    case_sensitive: bool = True,
    snapshot_id: Optional[int] = None,
    options: Properties = EMPTY_DICT,
):
    """Store the table and the scan settings on the instance.

    Args:
        table: The table this scan reads from.
        row_filter: Row filter expression; AlwaysTrue() is stored when omitted.
        partition_filter: Partition filter expression; AlwaysTrue() is stored when omitted.
        selected_fields: Names of the columns to select; the default ("*",) is the wildcard selection.
        case_sensitive: Case-sensitivity flag stored on the scan.
        snapshot_id: ID of the snapshot to scan, or None when no snapshot is pinned.
        options: Additional properties stored on the scan.
    """
    self.table = table
    # A missing filter is replaced by AlwaysTrue()
    self.row_filter = row_filter or AlwaysTrue()
    self.partition_filter = partition_filter or AlwaysTrue()
    # PR review note attached to the line above: "Looks good."
    self.selected_fields = selected_fields
    self.case_sensitive = case_sensitive
    self.snapshot_id = snapshot_id
    self.options = options

def snapshot(self) -> Optional[Snapshot]:
    """Return the snapshot this scan reads from.

    Returns:
        The result of looking up ``snapshot_id`` on the table when an ID is set,
        otherwise the table's current snapshot.
    """
    # Compare against None instead of relying on truthiness, so an ID of 0 is still looked up
    if self.snapshot_id is not None:
        return self.table.snapshot_by_id(self.snapshot_id)
    return self.table.current_snapshot()

def projection(self) -> Schema:
    """Return the schema this scan projects.

    Starts from the table's schema and switches to the schema registered under
    the snapshot's schema ID when the scan has a snapshot that carries one.
    A wildcard selection returns that schema unchanged; otherwise the selected
    fields are picked from it, honouring ``case_sensitive``.
    """
    snapshot_schema = self.table.schema()
    snapshot = self.snapshot()
    # Explicit None checks: a schema ID of 0 is falsy but must still select the snapshot's schema
    if snapshot is not None and snapshot.schema_id is not None:
        snapshot_schema = self.table.schemas()[snapshot.schema_id]

    if "*" in self.selected_fields:
        return snapshot_schema

    return snapshot_schema.select(*self.selected_fields, case_sensitive=self.case_sensitive)

def plan_files(self):
    """Placeholder: file planning is not available yet, so this always raises NotImplementedError."""
    message = "Not yet implemented"
    raise NotImplementedError(message)

def to_arrow(self):
    """Placeholder: Arrow conversion is not available yet, so this always raises NotImplementedError."""
    message = "Not yet implemented"
    raise NotImplementedError(message)

def update(self, **overrides) -> TableScan:
    """Creates a copy of this table scan with updated fields.

    The copy is built by calling the scan's own class with every instance
    attribute as a keyword argument, with ``overrides`` taking precedence.
    This instance is left unchanged.

    Args:
        **overrides: Attribute values that replace the current ones in the copy.

    Returns:
        A new scan of the same class as this one.
    """
    # type(self) rather than a hard-coded class, so a subclass gets a copy of its own type
    return type(self)(**{**self.__dict__, **overrides})

def use_ref(self, name: str):
    """Return a copy of this scan pinned to the snapshot that ``name`` resolves to.

    Args:
        name: Name passed to the table's ``snapshot_by_name`` lookup.

    Returns:
        A copy of this scan whose ``snapshot_id`` is the ID of the resolved snapshot.

    Raises:
        ValueError: If this scan already has a snapshot ID, or if the lookup finds no snapshot.
    """
    # Compare against None so that a snapshot ID of 0 also counts as "already set"
    if self.snapshot_id is not None:
        raise ValueError(f"Cannot override ref, already set snapshot id={self.snapshot_id}")
    if snapshot := self.table.snapshot_by_name(name):
        return self.update(snapshot_id=snapshot.snapshot_id)

    raise ValueError(f"Cannot scan unknown ref={name}")

def select(self, *field_names: str) -> TableScan:
    """Return a copy of this scan restricted to the given field names.

    When the current selection contains the wildcard "*", the given names
    replace it. Otherwise the result is the intersection of the current
    selection and the given names, kept in the order of the current selection.

    Args:
        *field_names: Names of the fields to select.

    Returns:
        A copy of this scan with the new ``selected_fields``.
    """
    if "*" in self.selected_fields:
        return self.update(selected_fields=field_names)

    requested = set(field_names)
    # dict.fromkeys drops duplicates while keeping first-seen order, so the
    # result is deterministic (a plain set intersection has arbitrary order)
    narrowed = tuple(field for field in dict.fromkeys(self.selected_fields) if field in requested)
    return self.update(selected_fields=narrowed)

def filter_rows(self, new_row_filter: BooleanExpression) -> TableScan:
    """Return a copy of this scan whose row filter is the current one AND-ed with ``new_row_filter``."""
    combined_filter = And(self.row_filter, new_row_filter)
    return self.update(row_filter=combined_filter)

def filter_partitions(self, new_partition_filter: BooleanExpression) -> TableScan:
    """Return a copy of this scan whose partition filter is the current one AND-ed with ``new_partition_filter``."""
    combined_filter = And(self.partition_filter, new_partition_filter)
    return self.update(partition_filter=combined_filter)

def with_case_sensitive(self, case_sensitive: bool = True) -> TableScan:
    """Return a copy of this scan with its ``case_sensitive`` flag replaced."""
    overrides = {"case_sensitive": case_sensitive}
    return self.update(**overrides)
83 changes: 83 additions & 0 deletions python/tests/table/test_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@

import pytest

from pyiceberg.expressions import (
AlwaysTrue,
And,
EqualTo,
In,
)
from pyiceberg.schema import Schema
from pyiceberg.table import PartitionSpec, Table
from pyiceberg.table.metadata import TableMetadataV2
Expand Down Expand Up @@ -179,3 +185,80 @@ def test_history(table):
SnapshotLogEntry(snapshot_id="3051729675574597004", timestamp_ms=1515100955770),
SnapshotLogEntry(snapshot_id="3055729675574597004", timestamp_ms=1555100955770),
]


def test_table_scan_select(table: Table):
    """select() replaces the wildcard selection and then narrows an explicit one."""
    base_scan = table.scan()
    assert base_scan.selected_fields == ("*",)

    two_columns = base_scan.select("a", "b")
    assert two_columns.selected_fields == ("a", "b")

    narrowed = base_scan.select("a", "c").select("a")
    assert narrowed.selected_fields == ("a",)


def test_table_scan_row_filter(table: Table):
    """filter_rows() combines each new expression with the existing row filter."""
    base_scan = table.scan()
    assert base_scan.row_filter == AlwaysTrue()

    single = base_scan.filter_rows(EqualTo("x", 10))  # type: ignore
    assert single.row_filter == EqualTo("x", 10)  # type: ignore

    chained = base_scan.filter_rows(EqualTo("x", 10)).filter_rows(In("y", (10, 11)))  # type: ignore
    expected = And(EqualTo("x", 10), In("y", (10, 11)))  # type: ignore
    assert chained.row_filter == expected


def test_table_scan_partition_filter(table: Table):
    """filter_partitions() combines each new expression with the existing partition filter."""
    scan = table.scan()
    # Check the partition filter here (not the row filter): this test is about partition filters
    assert scan.partition_filter == AlwaysTrue()
    assert scan.filter_partitions(EqualTo("x", 10)).partition_filter == EqualTo("x", 10)  # type: ignore
    assert scan.filter_partitions(EqualTo("x", 10)).filter_partitions(In("y", (10, 11))).partition_filter == And(  # type: ignore
        EqualTo("x", 10), In("y", (10, 11))  # type: ignore
    )


def test_table_scan_ref(table: Table):
    """use_ref() pins the scan to the snapshot ID that the ref resolves to."""
    pinned_scan = table.scan().use_ref("test")
    expected_snapshot_id = 3051729675574597004
    assert pinned_scan.snapshot_id == expected_snapshot_id


def test_table_scan_ref_does_not_exists(table: Table):
    """use_ref() raises ValueError for a ref name the table does not know."""
    base_scan = table.scan()

    with pytest.raises(ValueError) as raised:
        base_scan.use_ref("boom")

    message = str(raised.value)
    assert "Cannot scan unknown ref=boom" in message


def test_table_scan_projection_full_schema(table: Table):
    """Selecting every column by name projects the complete schema."""
    expected_schema = Schema(
        NestedField(field_id=1, name="x", field_type=LongType(), required=True),
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        NestedField(field_id=3, name="z", field_type=LongType(), required=True),
        schema_id=1,
        identifier_field_ids=[1, 2],
    )
    projected = table.scan().select("x", "y", "z").projection()
    assert projected == expected_schema


def test_table_scan_projection_single_column(table: Table):
    """Selecting one column projects a schema containing only that column."""
    expected_schema = Schema(
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        schema_id=1,
        identifier_field_ids=[2],
    )
    projected = table.scan().select("y").projection()
    assert projected == expected_schema


def test_table_scan_projection_single_column_case_sensitive(table: Table):
    """With case sensitivity switched off, selecting "Y" resolves to column "y"."""
    expected_schema = Schema(
        NestedField(field_id=2, name="y", field_type=LongType(), required=True, doc="comment"),
        schema_id=1,
        identifier_field_ids=[2],
    )
    case_insensitive_scan = table.scan().with_case_sensitive(False)
    projected = case_insensitive_scan.select("Y").projection()
    assert projected == expected_schema


def test_table_scan_projection_unknown_column(table: Table):
    """Projecting a column that is not in the schema raises ValueError."""
    base_scan = table.scan()

    with pytest.raises(ValueError) as raised:
        base_scan.select("a").projection()

    message = str(raised.value)
    assert "Could not find column: 'a'" in message