[Python] Initial PyArrow writer (#566)
* Initial writer implementation

* Add basic partitioning support

* Update docs and link to other projects

* Add Pandas support

* Test writer stats and partitioning

* Test statistics

* Enforce protocol version

* Add experimental to docstring

* Need typing extensions for checking now

* Add nitpick ignore for typing_extensions
wjones127 authored Mar 20, 2022
1 parent ff225e1 commit 346f51a
Showing 33 changed files with 882 additions and 44 deletions.
4 changes: 3 additions & 1 deletion .gitignore
@@ -10,4 +10,6 @@ tlaplus/*.toolbox/*/MC.cfg
tlaplus/*.toolbox/*/[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*-[0-9]*/
/.idea
.vscode
.env
.env
**/.DS_Store
**/.python-version
2 changes: 1 addition & 1 deletion README.adoc
@@ -73,7 +73,7 @@ link:https://github.com/rajasekarv/vega[vega], etc. It also provides bindings to

| High-level file writer
|
|
| link:https://github.com/delta-io/delta-rs/issues/542[#542]
|

| Optimize
1 change: 1 addition & 0 deletions python/deltalake/__init__.py
@@ -2,3 +2,4 @@
from .deltalake import PyDeltaTableError, RawDeltaTable, rust_core_version
from .schema import DataType, Field, Schema
from .table import DeltaTable, Metadata
from .writer import write_deltalake
4 changes: 2 additions & 2 deletions python/deltalake/schema.py
@@ -205,7 +205,7 @@ def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType:
key,
pyarrow.list_(
pyarrow.field(
"element", pyarrow.struct([pyarrow_field_from_dict(value_type)])
"entries", pyarrow.struct([pyarrow_field_from_dict(value_type)])
)
),
)
@@ -218,7 +218,7 @@ def pyarrow_datatype_from_dict(json_dict: Dict[str, Any]) -> pyarrow.DataType:
elif type_class == "list":
field = json_dict["children"][0]
element_type = pyarrow_datatype_from_dict(field)
return pyarrow.list_(pyarrow.field("element", element_type))
return pyarrow.list_(pyarrow.field("item", element_type))
elif type_class == "struct":
fields = [pyarrow_field_from_dict(field) for field in json_dict["children"]]
return pyarrow.struct(fields)
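
The renames above change the names PyArrow gives to nested children ("entries" for the struct inside a map-style list, "item" for a plain list element). A minimal sketch (not part of the diff) of how such nested fields are spelled with stock PyArrow:

    import pyarrow as pa

    # Plain list whose element field is named "item"
    int_list = pa.list_(pa.field("item", pa.int32()))

    # Map-like column encoded as a list of "entries" structs (key/value pairs)
    map_like = pa.list_(
        pa.field(
            "entries",
            pa.struct([pa.field("key", pa.string()), pa.field("value", pa.int64())]),
        )
    )

    schema = pa.schema([pa.field("xs", int_list), pa.field("m", map_like)])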
10 changes: 9 additions & 1 deletion python/deltalake/table.py
@@ -1,7 +1,7 @@
import json
import warnings
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, List, NamedTuple, Optional, Tuple, Union

import pyarrow
import pyarrow.fs as pa_fs
@@ -63,6 +63,11 @@ def __str__(self) -> str:
)


class ProtocolVersions(NamedTuple):
min_reader_version: int
min_writer_version: int


@dataclass(init=False)
class DeltaTable:
"""Create a DeltaTable instance."""
@@ -219,6 +224,9 @@ def metadata(self) -> Metadata:
"""
return self._metadata

def protocol(self) -> ProtocolVersions:
return ProtocolVersions(*self._table.protocol_versions())

def history(self, limit: Optional[int] = None) -> List[Dict[str, Any]]:
"""
Run the history command on the DeltaTable.
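
Beyond the diff itself, a minimal sketch of how the new protocol() accessor can be used to guard a write, mirroring the check the writer below performs (the table path is hypothetical):

    from deltalake import DeltaTable

    dt = DeltaTable("path/to/table")
    protocol = dt.protocol()
    if protocol.min_writer_version > 1:
        raise RuntimeError(
            f"min_writer_version is {protocol.min_writer_version}, "
            "but this client only supports version 1."
        )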
243 changes: 243 additions & 0 deletions python/deltalake/writer.py
@@ -0,0 +1,243 @@
import json
import uuid
from dataclasses import dataclass
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Dict, Iterable, Iterator, List, Mapping, Optional, Union

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader
from typing_extensions import Literal

from .deltalake import PyDeltaTableError
from .deltalake import write_new_deltalake as _write_new_deltalake
from .table import DeltaTable


class DeltaTableProtocolError(PyDeltaTableError):
pass


@dataclass
class AddAction:
path: str
size: int
partition_values: Mapping[str, Optional[str]]
modification_time: int
data_change: bool
stats: str


def write_deltalake(
table_or_uri: Union[str, DeltaTable],
data: Union[
pd.DataFrame,
pa.Table,
pa.RecordBatch,
Iterable[pa.RecordBatch],
RecordBatchReader,
],
schema: Optional[pa.Schema] = None,
partition_by: Optional[List[str]] = None,
filesystem: Optional[pa_fs.FileSystem] = None,
mode: Literal["error", "append", "overwrite", "ignore"] = "error",
) -> None:
"""Write to a Delta Lake table (Experimental)
If the table does not already exist, it will be created.
This function currently only supports writer protocol version 1. If
attempting to write to an existing table with a higher min_writer_version,
this function will throw a DeltaTableProtocolError.
:param table_or_uri: URI of a table or a DeltaTable object.
:param data: Data to write. If passing iterable, the schema must also be given.
:param schema: Optional schema to write.
:param partition_by: List of columns to partition the table by. Only required
when creating a new table.
:param filesystem: Optional filesystem to pass to PyArrow. If not provided, it
will be inferred from the URI.
:param mode: How to handle existing data. Default is to error if table
already exists. If 'append', will add new data. If 'overwrite', will
replace table with new data. If 'ignore', will not write anything if
table already exists.
"""
if isinstance(data, pd.DataFrame):
data = pa.Table.from_pandas(data)

if schema is None:
if isinstance(data, RecordBatchReader):
schema = data.schema
elif isinstance(data, Iterable):
raise ValueError("You must provide schema if data is Iterable")
else:
schema = data.schema

if isinstance(table_or_uri, str):
table = try_get_deltatable(table_or_uri)
table_uri = table_or_uri
else:
table = table_or_uri
table_uri = table._table.table_uri()

# TODO: Pass through filesystem once it is complete
# if filesystem is None:
# filesystem = pa_fs.PyFileSystem(DeltaStorageHandler(table_uri))

if table: # already exists
if mode == "error":
raise AssertionError("DeltaTable already exists.")
elif mode == "ignore":
return

current_version = table.version()

if partition_by:
assert partition_by == table.metadata().partition_columns

if table.protocol().min_writer_version > 1:
raise DeltaTableProtocolError(
"This table's min_writer_version is "
f"{table.protocol().min_writer_version}, "
"but this method only supports version 1."
)
else: # creating a new table
current_version = -1

# TODO: Don't allow writing to non-empty directory
# Blocked on: Finish filesystem implementation in fs.py
# assert len(filesystem.get_file_info(pa_fs.FileSelector(table_uri, allow_not_found=True))) == 0

if partition_by:
partition_schema = pa.schema([schema.field(name) for name in partition_by])
partitioning = ds.partitioning(partition_schema, flavor="hive")
else:
partitioning = None

add_actions: List[AddAction] = []

def visitor(written_file: Any) -> None:
partition_values = get_partitions_from_path(table_uri, written_file.path)
stats = get_file_stats_from_metadata(written_file.metadata)

add_actions.append(
AddAction(
written_file.path,
written_file.metadata.serialized_size,
partition_values,
int(datetime.now().timestamp()),
True,
json.dumps(stats, cls=DeltaJSONEncoder),
)
)

ds.write_dataset(
data,
base_dir=table_uri,
basename_template=f"{current_version + 1}-{uuid.uuid4()}-{{i}}.parquet",
format="parquet",
partitioning=partitioning,
# It will not accept a schema if using a RBR
schema=schema if not isinstance(data, RecordBatchReader) else None,
file_visitor=visitor,
existing_data_behavior="overwrite_or_ignore",
)

if table is None:
_write_new_deltalake(table_uri, schema, add_actions, mode, partition_by or [])
else:
table._table.create_write_transaction(
add_actions,
mode,
partition_by or [],
)


class DeltaJSONEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if isinstance(obj, bytes):
return obj.decode("unicode_escape")
elif isinstance(obj, date):
return obj.isoformat()
elif isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, Decimal):
return str(obj)
# Let the base class default method raise the TypeError
return json.JSONEncoder.default(self, obj)
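# Illustration only (not part of this commit): the encoder above lets stats that
# contain dates, timestamps, and decimals pass through json.dumps, e.g.
#   json.dumps({"d": date(2022, 3, 20), "x": Decimal("1.5")}, cls=DeltaJSONEncoder)
#   == '{"d": "2022-03-20", "x": "1.5"}'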


def try_get_deltatable(table_uri: str) -> Optional[DeltaTable]:
try:
return DeltaTable(table_uri)
except PyDeltaTableError as err:
if "Not a Delta table" not in str(err):
raise
return None


def get_partitions_from_path(base_path: str, path: str) -> Dict[str, str]:
path = path.split(base_path, maxsplit=1)[1]
parts = path.split("/")
parts.pop() # remove filename
out = {}
for part in parts:
if part == "":
continue
key, value = part.split("=", maxsplit=1)
out[key] = value
return out
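# Illustration only (not part of this commit): for a Hive-style partition path,
#   get_partitions_from_path("/data/tbl", "/data/tbl/year=2021/part-0.parquet")
#   == {"year": "2021"}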


def get_file_stats_from_metadata(
metadata: Any,
) -> Dict[str, Union[int, Dict[str, Any]]]:
stats = {
"numRecords": metadata.num_rows,
"minValues": {},
"maxValues": {},
"nullCount": {},
}

def iter_groups(metadata: Any) -> Iterator[Any]:
for i in range(metadata.num_row_groups):
yield metadata.row_group(i)

for column_idx in range(metadata.num_columns):
name = metadata.row_group(0).column(column_idx).path_in_schema
# If stats missing, then we can't know aggregate stats
if all(
group.column(column_idx).is_stats_set for group in iter_groups(metadata)
):
stats["nullCount"][name] = sum(
group.column(column_idx).statistics.null_count
for group in iter_groups(metadata)
)

# I assume for now this is based on data type, and thus is
# consistent between groups
if metadata.row_group(0).column(column_idx).statistics.has_min_max:
# Min and Max are recorded in physical type, not logical type
# https://stackoverflow.com/questions/66753485/decoding-parquet-min-max-statistics-for-decimal-type
# TODO: Add logic to decode physical type for DATE, DECIMAL
logical_type = (
metadata.row_group(0)
.column(column_idx)
.statistics.logical_type.type
)
if logical_type not in ["STRING", "INT", "TIMESTAMP", "NONE"]:
continue
stats["minValues"][name] = min(
group.column(column_idx).statistics.min
for group in iter_groups(metadata)
)
stats["maxValues"][name] = max(
group.column(column_idx).statistics.max
for group in iter_groups(metadata)
)
return stats
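
A short usage sketch of the new writer (the table path and column names are illustrative, not taken from the commit). As the docstring notes, an explicit schema is required when the data is a plain iterable of record batches:

    import pyarrow as pa
    from deltalake.writer import write_deltalake

    schema = pa.schema([("id", pa.int64()), ("country", pa.string())])
    batch = pa.record_batch(
        [pa.array([1, 2], type=pa.int64()), pa.array(["US", "DE"])],
        schema=schema,
    )

    # A list is an Iterable, not a RecordBatchReader, so schema= must be passed.
    write_deltalake("path/to/table", [batch], schema=schema, partition_by=["country"])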
5 changes: 5 additions & 0 deletions python/docs/source/api_reference.rst
@@ -7,6 +7,11 @@ DeltaTable
.. automodule:: deltalake.table
:members:

Writing DeltaTables
-------------------

.. autofunction:: deltalake.write_deltalake

DeltaSchema
-----------

15 changes: 14 additions & 1 deletion python/docs/source/conf.py
@@ -42,7 +42,12 @@ def get_release_version() -> str:
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ["sphinx_rtd_theme", "sphinx.ext.autodoc", "edit_on_github"]
extensions = [
"sphinx_rtd_theme",
"sphinx.ext.autodoc",
"sphinx.ext.intersphinx",
"edit_on_github",
]
autodoc_typehints = "description"
nitpicky = True
nitpick_ignore = [
@@ -52,6 +57,7 @@ def get_release_version() -> str:
("py:class", "pyarrow.lib.DataType"),
("py:class", "pyarrow.lib.Field"),
("py:class", "pyarrow.lib.NativeFile"),
("py:class", "pyarrow.lib.RecordBatchReader"),
("py:class", "pyarrow._fs.FileSystem"),
("py:class", "pyarrow._fs.FileInfo"),
("py:class", "pyarrow._fs.FileSelector"),
@@ -84,3 +90,10 @@ def get_release_version() -> str:
edit_on_github_project = "delta-io/delta-rs"
edit_on_github_branch = "main"
page_source_prefix = "python/docs/source"


intersphinx_mapping = {
"pyarrow": ("https://arrow.apache.org/docs/", None),
"pyspark": ("https://spark.apache.org/docs/latest/api/python/", None),
"pandas": ("https://pandas.pydata.org/docs/", None),
}
30 changes: 29 additions & 1 deletion python/docs/source/usage.rst
@@ -328,4 +328,32 @@ Optimizing tables is not currently supported.
Writing Delta Tables
--------------------

Writing Delta tables is not currently supported.
.. py:currentmodule:: deltalake

.. warning::
   The writer is currently *experimental*. Please use on test data first, not
   on production data. Report any issues at https://github.com/delta-io/delta-rs/issues.

For overwrites and appends, use :py:func:`write_deltalake`. If the table does not
already exist, it will be created. The ``data`` parameter accepts a Pandas
DataFrame, a PyArrow Table or RecordBatch, a RecordBatchReader, or an iterable of
PyArrow RecordBatches.

.. code-block:: python

    >>> from deltalake.writer import write_deltalake
    >>> df = pd.DataFrame({'x': [1, 2, 3]})
    >>> write_deltalake('path/to/table', df)

.. note::
   :py:func:`write_deltalake` accepts a Pandas DataFrame, but will convert it to
   an Arrow table before writing. See caveats in :doc:`pyarrow:python/pandas`.

By default, writes create a new table and error if it already exists. This is
controlled by the ``mode`` parameter, which mirrors the behavior of Spark's
:py:meth:`pyspark.sql.DataFrameWriter.saveAsTable` method. To overwrite, pass in
``mode='overwrite'``; to append, pass in ``mode='append'``:

.. code-block:: python

    >>> write_deltalake('path/to/table', df, mode='overwrite')
    >>> write_deltalake('path/to/table', df, mode='append')
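
Not shown in the diff, but given the ``partition_by`` parameter the writer introduces, a partitioned write over a hypothetical ``country`` column would look roughly like:

.. code-block:: python

    >>> write_deltalake('path/to/table', df, partition_by=['country'])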