Add invariant enforcement support (#834)
# Description

Adds support for retrieving invariants from the Delta schema, along with a
`DeltaDataChecker` struct that uses DataFusion to check them and report
useful errors.

This also hooks invariant checking up to the Python bindings, allowing
`write_deltalake()` to support Writer Protocol V2.
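
Roughly, the write path now reads invariants off the table schema and routes every record batch through the checker before it is written. A minimal sketch of the binding surface added here (class and attribute names are taken from the `_internal.pyi` stub below; calling the checker by hand outside `write_deltalake()`, and the table path and column values, are illustrative only):

```python
import pyarrow as pa

from deltalake import DeltaTable
from deltalake._internal import DeltaDataChecker  # internal binding added in this PR

# Hypothetical standalone use; write_deltalake() wires this up automatically.
table = DeltaTable("path/to/table")         # assumed existing table with invariants
invariants = table.schema().invariants      # e.g. [("c1", "c1 > 3")]

checker = DeltaDataChecker(invariants)
batch = pa.RecordBatch.from_pydict({"c1": pa.array([4, 5], type=pa.int32())})
checker.check_batch(batch)                  # raises PyDeltaTableError on a violation
```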

I looked briefly at the Rust writer, but realized we don't want to introduce a
dependency on DataFusion there. We should discuss how we want to design that
API. I suspect we'll turn `DeltaDataChecker` into a trait, so a DataFusion
implementation is available out of the box while other engines can implement
it themselves if they don't wish to use DataFusion.

# Related Issue(s)

- closes #592
- closes #575

# Documentation


https://github.com/delta-io/delta/blob/master/PROTOCOL.md#column-invariants
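
For reference, a hedged sketch (with a hypothetical column `c1`) of the on-disk format that document describes and how the new `Schema.invariants` getter surfaces it:

```python
import json

# Per PROTOCOL.md, a column invariant is stored in the Delta schema field's
# metadata under the "delta.invariants" key (the same format the PySpark test
# below writes):
field_metadata = {"delta.invariants": '{"expression": {"expression": "c1 > 3"}}'}

# The new Schema.invariants getter flattens this into (field path, SQL) pairs,
# e.g. [("c1", "c1 > 3")] for the field above.
sql = json.loads(field_metadata["delta.invariants"])["expression"]["expression"]
print(sql)  # c1 > 3
```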
wjones127 authored Sep 28, 2022
1 parent 7824a37 commit e2cbc79
Showing 12 changed files with 554 additions and 46 deletions.
2 changes: 1 addition & 1 deletion python/Cargo.toml
@@ -36,4 +36,4 @@ features = ["extension-module", "abi3", "abi3-py37"]
[dependencies.deltalake]
path = "../rust"
version = "0"
features = ["s3", "azure", "glue", "gcs", "python"]
features = ["s3", "azure", "glue", "gcs", "python", "datafusion-ext"]
7 changes: 6 additions & 1 deletion python/deltalake/_internal.pyi
@@ -1,5 +1,5 @@
import sys
-from typing import Any, Callable, Dict, List, Mapping, Optional, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

if sys.version_info >= (3, 8):
from typing import Literal
@@ -118,6 +118,7 @@ class StructType:
class Schema:
def __init__(self, fields: List[Field]) -> None: ...
fields: List[Field]
invariants: List[Tuple[str, str]]

def to_json(self) -> str: ...
@staticmethod
@@ -212,3 +213,7 @@ class DeltaFileSystemHandler:
self, path: str, metadata: dict[str, str] | None = None
) -> ObjectOutputStream:
"""Open an output stream for sequential writing."""

class DeltaDataChecker:
def __init__(self, invariants: List[Tuple[str, str]]) -> None: ...
def check_batch(self, batch: pa.RecordBatch) -> None: ...
28 changes: 26 additions & 2 deletions python/deltalake/writer.py
@@ -35,6 +35,7 @@
import pyarrow.fs as pa_fs
from pyarrow.lib import RecordBatchReader

from ._internal import DeltaDataChecker as _DeltaDataChecker
from ._internal import PyDeltaTableError
from ._internal import write_new_deltalake as _write_new_deltalake
from .table import DeltaTable
@@ -192,11 +193,11 @@ def write_deltalake(
if partition_by:
assert partition_by == table.metadata().partition_columns

-        if table.protocol().min_writer_version > 1:
+        if table.protocol().min_writer_version > 2:
raise DeltaTableProtocolError(
"This table's min_writer_version is "
f"{table.protocol().min_writer_version}, "
"but this method only supports version 1."
"but this method only supports version 2."
)
else: # creating a new table
current_version = -1
@@ -234,6 +235,29 @@ def visitor(written_file: Any) -> None:
)
)

if table is not None:
        # We don't currently provide a way to set invariants
        # (and maybe never will), so only enforce them if they already exist.
invariants = table.schema().invariants
checker = _DeltaDataChecker(invariants)

def validate_batch(batch: pa.RecordBatch) -> pa.RecordBatch:
checker.check_batch(batch)
return batch

if isinstance(data, RecordBatchReader):
batch_iter = data
elif isinstance(data, pa.RecordBatch):
batch_iter = [data]
elif isinstance(data, pa.Table):
batch_iter = data.to_batches()
else:
batch_iter = data

data = RecordBatchReader.from_batches(
schema, (validate_batch(batch) for batch in batch_iter)
)

ds.write_dataset(
data,
base_dir="/",
38 changes: 37 additions & 1 deletion python/src/lib.rs
@@ -8,14 +8,16 @@ use chrono::{DateTime, FixedOffset, Utc};
use deltalake::action::{
self, Action, ColumnCountStat, ColumnValueStat, DeltaOperation, SaveMode, Stats,
};
use deltalake::arrow::record_batch::RecordBatch;
use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
use deltalake::builder::DeltaTableBuilder;
use deltalake::delta_datafusion::DeltaDataChecker;
use deltalake::partitions::PartitionFilter;
use deltalake::DeltaDataTypeLong;
use deltalake::DeltaDataTypeTimestamp;
use deltalake::DeltaTableMetaData;
use deltalake::DeltaTransactionOptions;
-use deltalake::Schema;
+use deltalake::{Invariant, Schema};
use pyo3::create_exception;
use pyo3::exceptions::PyException;
use pyo3::exceptions::PyValueError;
@@ -585,6 +587,39 @@ fn write_new_deltalake(
Ok(())
}

#[pyclass(name = "DeltaDataChecker", text_signature = "(invariants)")]
struct PyDeltaDataChecker {
inner: DeltaDataChecker,
rt: tokio::runtime::Runtime,
}

#[pymethods]
impl PyDeltaDataChecker {
#[new]
fn new(invariants: Vec<(String, String)>) -> Self {
let invariants: Vec<Invariant> = invariants
.into_iter()
.map(|(field_name, invariant_sql)| Invariant {
field_name,
invariant_sql,
})
.collect();
Self {
inner: DeltaDataChecker::new(invariants),
rt: tokio::runtime::Runtime::new().unwrap(),
}
}

fn check_batch(&self, batch: RecordBatch) -> PyResult<()> {
self.rt.block_on(async {
self.inner
.check_batch(&batch)
.await
.map_err(PyDeltaTableError::from_raw)
})
}
}

#[pymodule]
// module name need to match project name
fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
@@ -594,6 +629,7 @@ fn _internal(py: Python, m: &PyModule) -> PyResult<()> {
m.add_function(pyo3::wrap_pyfunction!(write_new_deltalake, m)?)?;
m.add_class::<RawDeltaTable>()?;
m.add_class::<RawDeltaTableMetaData>()?;
m.add_class::<PyDeltaDataChecker>()?;
m.add("PyDeltaTableError", py.get_type::<PyDeltaTableError>())?;
// There are issues with submodules, so we will expose them flat for now
// See also: https://github.com/PyO3/pyo3/issues/759
18 changes: 18 additions & 0 deletions python/src/schema.rs
@@ -1064,4 +1064,22 @@ impl PySchema {
Err(PyTypeError::new_err("Type is not a struct"))
}
}

/// The list of invariants on the table.
///
/// :rtype: List[Tuple[str, str]]
    /// :return: a list of tuples, one per invariant; the first string is
    ///     the field path and the second is the SQL of the invariant.
#[getter]
fn invariants(self_: PyRef<'_, Self>) -> PyResult<Vec<(String, String)>> {
let super_ = self_.as_ref();
let invariants = super_
.inner_type
.get_invariants()
.map_err(|err| PyException::new_err(err.to_string()))?;
Ok(invariants
.into_iter()
.map(|invariant| (invariant.field_name, invariant.invariant_sql))
.collect())
}
}
115 changes: 115 additions & 0 deletions python/tests/pyspark_integration/test_write_to_pyspark.py
@@ -0,0 +1,115 @@
"""Tests that deltalake(delta-rs) can write to tables written by PySpark"""
import pathlib

import pyarrow as pa
import pytest

from deltalake import write_deltalake
from deltalake._internal import PyDeltaTableError
from deltalake.writer import DeltaTableProtocolError

from .utils import assert_spark_read_equal, get_spark

try:
import delta
import delta.pip_utils
import delta.tables
import pyspark

spark = get_spark()
except ModuleNotFoundError:
pass


@pytest.mark.pyspark
@pytest.mark.integration
def test_write_basic(tmp_path: pathlib.Path):
# Write table in Spark
spark = get_spark()
schema = pyspark.sql.types.StructType(
[
pyspark.sql.types.StructField(
"c1",
dataType=pyspark.sql.types.IntegerType(),
nullable=True,
)
]
)
spark.createDataFrame([(4,)], schema=schema).write.save(
str(tmp_path),
mode="append",
format="delta",
)
# Overwrite table in deltalake
data = pa.table({"c1": pa.array([5, 6], type=pa.int32())})
write_deltalake(str(tmp_path), data, mode="overwrite")

# Read table in Spark
assert_spark_read_equal(data, str(tmp_path), sort_by="c1")


@pytest.mark.pyspark
@pytest.mark.integration
def test_write_invariant(tmp_path: pathlib.Path):
# Write table in Spark with invariant
spark = get_spark()

schema = pyspark.sql.types.StructType(
[
pyspark.sql.types.StructField(
"c1",
dataType=pyspark.sql.types.IntegerType(),
nullable=True,
metadata={
"delta.invariants": '{"expression": { "expression": "c1 > 3"} }'
},
)
]
)

delta.tables.DeltaTable.create(spark).location(str(tmp_path)).addColumns(
schema
).execute()

spark.createDataFrame([(4,)], schema=schema).write.save(
str(tmp_path),
mode="append",
format="delta",
)

# Cannot write invalid data to the table
invalid_data = pa.table({"c1": pa.array([6, 2], type=pa.int32())})
with pytest.raises(
        PyDeltaTableError, match=r"Invariant \(c1 > 3\) violated by value .+2"
    ):
write_deltalake(str(tmp_path), invalid_data, mode="overwrite")

# Can write valid data to the table
valid_data = pa.table({"c1": pa.array([5, 6], type=pa.int32())})
write_deltalake(str(tmp_path), valid_data, mode="append")

expected = pa.table({"c1": pa.array([4, 5, 6], type=pa.int32())})
assert_spark_read_equal(expected, str(tmp_path), sort_by="c1")


@pytest.mark.pyspark
@pytest.mark.integration
def test_checks_min_writer_version(tmp_path: pathlib.Path):
# Write table in Spark with constraint
spark = get_spark()

spark.createDataFrame([(4,)], schema=["c1"]).write.save(
str(tmp_path),
mode="append",
format="delta",
)

    # Adding a constraint upgrades the table's minWriterVersion
spark.sql(f"ALTER TABLE delta.`{str(tmp_path)}` ADD CONSTRAINT x CHECK (c1 > 2)")

with pytest.raises(
DeltaTableProtocolError, match="This table's min_writer_version is 3, but"
):
valid_data = pa.table({"c1": pa.array([5, 6])})
write_deltalake(str(tmp_path), valid_data, mode="append")
38 changes: 1 addition & 37 deletions python/tests/pyspark_integration/test_writer_readable.py
@@ -7,25 +7,7 @@

from deltalake import DeltaTable, write_deltalake

try:
from pandas.testing import assert_frame_equal
except ModuleNotFoundError:
_has_pandas = False
else:
_has_pandas = True


def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("MyApp")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
)
)
return delta.pip_utils.configure_spark_with_delta_pip(builder).getOrCreate()

from .utils import assert_spark_read_equal, get_spark

try:
import delta
@@ -38,24 +20,6 @@ def get_spark():
pass


def assert_spark_read_equal(
expected: pa.Table, uri: str, sort_by: List[str] = ["int32"]
):
df = spark.read.format("delta").load(uri)

# Spark and pyarrow don't convert these types to the same Pandas values
incompatible_types = ["timestamp", "struct"]

assert_frame_equal(
df.toPandas()
.sort_values(sort_by, ignore_index=True)
.drop(incompatible_types, axis="columns"),
expected.to_pandas()
.sort_values(sort_by, ignore_index=True)
.drop(incompatible_types, axis="columns"),
)


@pytest.mark.pyspark
@pytest.mark.integration
def test_basic_read(sample_data: pa.Table, existing_table: DeltaTable):
49 changes: 49 additions & 0 deletions python/tests/pyspark_integration/utils.py
@@ -0,0 +1,49 @@
from typing import List

import pyarrow as pa

try:
import delta
import delta.pip_utils
import delta.tables
import pyspark
except ModuleNotFoundError:
pass

try:
from pandas.testing import assert_frame_equal
except ModuleNotFoundError:
_has_pandas = False
else:
_has_pandas = True


def get_spark():
builder = (
pyspark.sql.SparkSession.builder.appName("MyApp")
.config("spark.sql.extensions", "io.delta.sql.DeltaSparkSessionExtension")
.config(
"spark.sql.catalog.spark_catalog",
"org.apache.spark.sql.delta.catalog.DeltaCatalog",
)
)
return delta.pip_utils.configure_spark_with_delta_pip(builder).getOrCreate()


def assert_spark_read_equal(
expected: pa.Table, uri: str, sort_by: List[str] = ["int32"]
):
spark = get_spark()
df = spark.read.format("delta").load(uri)

# Spark and pyarrow don't convert these types to the same Pandas values
incompatible_types = ["timestamp", "struct"]

assert_frame_equal(
df.toPandas()
.sort_values(sort_by, ignore_index=True)
.drop(incompatible_types, axis="columns", errors="ignore"),
expected.to_pandas()
.sort_values(sort_by, ignore_index=True)
.drop(incompatible_types, axis="columns", errors="ignore"),
)
2 changes: 1 addition & 1 deletion python/tests/test_writer.py
@@ -420,7 +420,7 @@ def test_writer_null_stats(tmp_path: pathlib.Path):


def test_writer_fails_on_protocol(existing_table: DeltaTable, sample_data: pa.Table):
-    existing_table.protocol = Mock(return_value=ProtocolVersions(1, 2))
+    existing_table.protocol = Mock(return_value=ProtocolVersions(1, 3))
with pytest.raises(DeltaTableProtocolError):
write_deltalake(existing_table, sample_data, mode="overwrite")
