Skip to content

Commit

Permalink
Extend the existing Python Delta Table API to expose WHEN NOT MATCHED…
Browse files Browse the repository at this point in the history
… BY SOURCE clause in merge commands.

Support for the clause was introduced in #1511 using the Scala Delta Table API; this patch extends the Python API to support the new clause.

See corresponding feature request: #1364

Adding python tests covering WHEN NOT MATCHED BY SOURCE to test_deltatable.py.

The extended API for NOT MATCHED BY SOURCE mirrors existing clauses (MATCHED/NOT MATCHED).
Usage:
```
        dt.merge(source, "key = k")
            .whenNotMatchedBySourceDelete(condition="value > 0")
            .whenNotMatchedBySourceUpdate(set={"value": "value + 0"})
            .execute()
```

Closes #1533

GitOrigin-RevId: 76c7aea481fdbbf47af36ef7251ed555749954ac
  • Loading branch information
johanl-db authored and scottsand-db committed Jan 5, 2023
1 parent c3e0a1a commit 9cb4dc4
Show file tree
Hide file tree
Showing 2 changed files with 189 additions and 8 deletions.
100 changes: 95 additions & 5 deletions python/delta/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,8 +710,8 @@ class DeltaMergeBuilder(object):
"""
Builder to specify how to merge data from source DataFrame into the target Delta table.
Use :py:meth:`delta.tables.DeltaTable.merge` to create an object of this class.
Using this builder, you can specify any number of ``whenMatched`` and ``whenNotMatched``
clauses. Here are the constraints on these clauses.
Using this builder, you can specify any number of ``whenMatched``, ``whenNotMatched`` and
``whenNotMatchedBySource`` clauses. Here are the constraints on these clauses.
- Constraints in the ``whenMatched`` clauses:
Expand All @@ -723,7 +723,7 @@ class DeltaMergeBuilder(object):
- When there are two ``whenMatched`` clauses and there are conditions (or the lack of)
such that a row matches both clauses, then the first clause/action is executed.
In other words, the order of the ``whenMatched`` clauses matter.
In other words, the order of the ``whenMatched`` clauses matters.
- If none of the ``whenMatched`` clauses match a source-target row pair that satisfy
the merge condition, then the target rows will not be updated or deleted.
Expand Down Expand Up @@ -755,6 +755,23 @@ class DeltaMergeBuilder(object):
... # for all columns in the delta table
})
- Constraints in the ``whenNotMatchedBySource`` clauses:
- Each ``whenNotMatchedBySource`` clause can have an optional condition. However, only the
last ``whenNotMatchedBySource`` clause may omit the condition.
- Conditions and update expressions in ``whenNotMatchedBySource`` clauses may only refer to
columns from the target Delta table.
- When there is more than one ``whenNotMatchedBySource`` clause and there are conditions (or
the lack of) such that a row satisfies multiple clauses, then the first clause/action
satisfied is executed. In other words, the order of the ``whenNotMatchedBySource`` clauses
matters.
- If no ``whenNotMatchedBySource`` clause is present or if it is present but the
non-matching target row does not satisfy any of the ``whenNotMatchedBySource`` clause
conditions, then the target row will not be updated or deleted.
Example 1 with conditions and update expressions as SQL formatted string::
deltaTable.alias("events").merge(
Expand All @@ -770,7 +787,12 @@ class DeltaMergeBuilder(object):
"date": "updates.date",
"eventId": "updates.eventId",
"data": "updates.data",
"count": "1"
"count": "1",
"missed_count": "0"
}
).whenNotMatchedBySourceUpdate(set =
{
"missed_count": "events.missed_count + 1"
}
).execute()
Expand All @@ -791,7 +813,12 @@ class DeltaMergeBuilder(object):
"date": col("updates.date"),
"eventId": col("updates.eventId"),
"data": col("updates.data"),
"count": lit("1")
"count": lit("1"),
"missed_count": lit("0")
}
).whenNotMatchedBySourceUpdate(set =
{
"missed_count": col("events.missed_count") + 1
}
).execute()
Expand Down Expand Up @@ -928,6 +955,60 @@ def whenNotMatchedInsertAll(
new_jbuilder = self.__getNotMatchedBuilder(condition).insertAll()
return DeltaMergeBuilder(self._spark, new_jbuilder)

@overload
def whenNotMatchedBySourceUpdate(
    self, condition: OptionalExpressionOrColumn, set: ColumnMapping
) -> "DeltaMergeBuilder":
    ...

@overload
def whenNotMatchedBySourceUpdate(
    self, *, set: ColumnMapping
) -> "DeltaMergeBuilder":
    ...

def whenNotMatchedBySourceUpdate(
    self,
    condition: OptionalExpressionOrColumn = None,
    set: OptionalColumnMapping = None
) -> "DeltaMergeBuilder":
    """
    Add a clause that updates target rows having no match in the source, applying the
    column assignments given in ``set``. When a ``condition`` is supplied, only rows for
    which it evaluates to true are updated.

    See :py:class:`~delta.tables.DeltaMergeBuilder` for complete usage details.

    :param condition: Optional condition of the update
    :type condition: str or pyspark.sql.Column
    :param set: Rules for computing the new values of the columns to update.
                *Note: This param is required.* It defaults to None only so that
                positional arguments keep the same order across languages.
    :type set: dict with str as keys and str or pyspark.sql.Column as values
    :return: a new builder wrapping the Java builder with this clause added

    .. versionadded:: 2.3
    """
    # Convert ``set`` before resolving the condition: the order decides which argument
    # error surfaces first when both are bad (e.g. ``set`` omitted and an invalid condition).
    update_map = DeltaTable._dict_to_jmap(
        self._spark, set, "'set' in whenNotMatchedBySourceUpdate")
    clause_builder = self.__getNotMatchedBySourceBuilder(condition)
    return DeltaMergeBuilder(self._spark, clause_builder.update(update_map))

@since(2.3)  # type: ignore[arg-type]
def whenNotMatchedBySourceDelete(
    self, condition: OptionalExpressionOrColumn = None
) -> "DeltaMergeBuilder":
    """
    Add a clause that deletes target rows having no match in the source. When a
    ``condition`` is supplied, only target rows for which it is true are deleted.

    See :py:class:`~delta.tables.DeltaMergeBuilder` for complete usage details.

    :param condition: Optional condition of the delete
    :type condition: str or pyspark.sql.Column
    :return: a new builder wrapping the Java builder with this clause added
    """
    clause_builder = self.__getNotMatchedBySourceBuilder(condition)
    return DeltaMergeBuilder(self._spark, clause_builder.delete())

@since(0.4) # type: ignore[arg-type]
def execute(self) -> None:
"""
Expand All @@ -953,6 +1034,15 @@ def __getNotMatchedBuilder(
else:
return self._jbuilder.whenNotMatched(DeltaTable._condition_to_jcolumn(condition))

def __getNotMatchedBySourceBuilder(
    self, condition: OptionalExpressionOrColumn = None
) -> "JavaObject":
    """
    Open a ``whenNotMatchedBySource`` clause on the wrapped Java builder.

    :param condition: Optional clause condition; when given it is converted with
                      ``DeltaTable._condition_to_jcolumn`` before being passed on.
    :return: the Java object returned by ``whenNotMatchedBySource``
    """
    if condition is not None:
        jcondition = DeltaTable._condition_to_jcolumn(condition)
        return self._jbuilder.whenNotMatchedBySource(jcondition)
    # No condition: use the zero-argument form of the clause.
    return self._jbuilder.whenNotMatchedBySource()


class DeltaTableBuilder(object):
"""
Expand Down
97 changes: 94 additions & 3 deletions python/delta/tests/test_deltatable.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ def reset_table() -> None:
dt.merge(source, "key = k") \
.whenMatchedUpdate(set={"value": "v + 0"}) \
.whenNotMatchedInsert(values={"key": "k", "value": "v + 0"}) \
.whenNotMatchedBySourceUpdate(set={"value": "value + 0"}) \
.execute()
self.__checkAnswer(dt.toDF(),
([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))
Expand All @@ -160,10 +161,47 @@ def reset_table() -> None:
dt.merge(source, expr("key = k")) \
.whenMatchedUpdate(set={"value": col("v") + 0}) \
.whenNotMatchedInsert(values={"key": "k", "value": col("v") + 0}) \
.whenNotMatchedBySourceUpdate(set={"value": col("value") + 0}) \
.execute()
self.__checkAnswer(dt.toDF(),
([('a', -1), ('b', 0), ('c', 3), ('d', 4), ('e', -5), ('f', -6)]))

# Multiple not matched by source update clauses
reset_table()
dt.merge(source, expr("key = k")) \
.whenNotMatchedBySourceUpdate(condition="key = 'c'", set={"value": "5"}) \
.whenNotMatchedBySourceUpdate(set={"value": "0"}) \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', 1), ('b', 2), ('c', 5), ('d', 0)]))

# Multiple not matched by source delete clauses
reset_table()
dt.merge(source, expr("key = k")) \
.whenNotMatchedBySourceDelete(condition="key = 'c'") \
.whenNotMatchedBySourceDelete() \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', 1), ('b', 2)]))

# Redundant not matched by source update and delete clauses
reset_table()
dt.merge(source, expr("key = k")) \
.whenNotMatchedBySourceUpdate(condition="key = 'c'", set={"value": "5"}) \
.whenNotMatchedBySourceUpdate(condition="key = 'c'", set={"value": "0"}) \
.whenNotMatchedBySourceUpdate(condition="key = 'd'", set={"value": "6"}) \
.whenNotMatchedBySourceDelete(condition="key = 'd'") \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', 1), ('b', 2), ('c', 5), ('d', 6)]))

# Interleaved update and delete clauses
reset_table()
dt.merge(source, expr("key = k")) \
.whenNotMatchedBySourceDelete(condition="key = 'c'") \
.whenNotMatchedBySourceUpdate(condition="key = 'c'", set={"value": "5"}) \
.whenNotMatchedBySourceDelete(condition="key = 'd'") \
.whenNotMatchedBySourceUpdate(set={"value": "6"}) \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', 1), ('b', 2)]))

# ============== Test clause conditions ==============

# String expressions in all conditions and dicts
Expand All @@ -172,8 +210,10 @@ def reset_table() -> None:
.whenMatchedUpdate(condition="k = 'a'", set={"value": "v + 0"}) \
.whenMatchedDelete(condition="k = 'b'") \
.whenNotMatchedInsert(condition="k = 'e'", values={"key": "k", "value": "v + 0"}) \
.whenNotMatchedBySourceUpdate(condition="key = 'c'", set={"value": col("value") + 0}) \
.whenNotMatchedBySourceDelete(condition="key = 'd'") \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))
self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('e', -5)]))

# Column expressions in all conditions and dicts
reset_table()
Expand All @@ -185,17 +225,23 @@ def reset_table() -> None:
.whenNotMatchedInsert(
condition=expr("k = 'e'"),
values={"key": "k", "value": col("v") + 0}) \
.whenNotMatchedBySourceUpdate(
condition=expr("key = 'c'"),
set={"value": col("value") + 0}) \
.whenNotMatchedBySourceDelete(condition=expr("key = 'd'")) \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))
self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('e', -5)]))

# Positional arguments
reset_table()
dt.merge(source, "key = k") \
.whenMatchedUpdate("k = 'a'", {"value": "v + 0"}) \
.whenMatchedDelete("k = 'b'") \
.whenNotMatchedInsert("k = 'e'", {"key": "k", "value": "v + 0"}) \
.whenNotMatchedBySourceUpdate("key = 'c'", {"value": "value + 0"}) \
.whenNotMatchedBySourceDelete("key = 'd'") \
.execute()
self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('d', 4), ('e', -5)]))
self.__checkAnswer(dt.toDF(), ([('a', -1), ('c', 3), ('e', -5)]))

# ============== Test updateAll/insertAll ==============

Expand Down Expand Up @@ -319,6 +365,51 @@ def reset_table() -> None:
.merge(source, "key = k")
.whenNotMatchedInsert(values="k = 'a'", condition={"value": 1}))

# ---- bad args in whenNotMatchedBySourceUpdate()
with self.assertRaisesRegex(ValueError, "cannot be None"):
(dt # type: ignore[call-overload]
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate({"value": "value"}))

with self.assertRaisesRegex(ValueError, "cannot be None"):
(dt # type: ignore[call-overload]
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate(1))

with self.assertRaisesRegex(ValueError, "cannot be None"):
(dt # type: ignore[call-overload]
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate(condition="key = 'a'"))

with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
(dt # type: ignore[call-overload]
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate(1, {"value": "value"}))

with self.assertRaisesRegex(TypeError, "must be a dict"):
(dt # type: ignore[call-overload]
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate("key = 'a'", 1))

with self.assertRaisesRegex(TypeError, "Values of dict in .* must contain only"):
(dt
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate(set={"value": 1})) # type: ignore[dict-item]

with self.assertRaisesRegex(TypeError, "Keys of dict in .* must contain only"):
(dt
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate(set={1: ""})) # type: ignore[dict-item]

with self.assertRaises(TypeError):
(dt # type: ignore[call-overload]
.merge(source, "key = k")
.whenNotMatchedBySourceUpdate(set="key = 'a'", condition={"value": 1}))

# bad args in whenNotMatchedBySourceDelete()
with self.assertRaisesRegex(TypeError, "must be a Spark SQL Column or a string"):
dt.merge(source, "key = k").whenNotMatchedBySourceDelete(1) # type: ignore[arg-type]

def test_history(self) -> None:
self.__writeDeltaTable([('a', 1), ('b', 2), ('c', 3)])
self.__overwriteDeltaTable([('a', 3), ('b', 2), ('c', 1)])
Expand Down

0 comments on commit 9cb4dc4

Please sign in to comment.