@dataclass
class IdentityGenerator:
    """
    Identity generator specifications for the identity column in the Delta table.

    The values are validated when the object is constructed so that a wrong
    argument type is reported as a ``TypeError`` with a readable message,
    before ``DeltaTableBuilder.addColumn`` forwards ``start`` and ``step`` to
    the column builder. A ``step`` of 0 is accepted here; it is rejected by
    ``DeltaTableBuilder`` when the column is added.

    :param start: the start for the identity column. Default is 1.
    :type start: int
    :param step: the step for the identity column. Default is 1.
    :type step: int
    :raises TypeError: if ``start`` or ``step`` is not an int (bool is rejected as well)
    """
    start: int = 1
    step: int = 1

    def __post_init__(self) -> None:
        for name, value in (("start", self.start), ("step", self.step)):
            # bool is a subclass of int, but it is not a meaningful identity
            # start/step value, so it is rejected explicitly.
            if isinstance(value, bool) or not isinstance(value, int):
                raise TypeError(
                    "Identity column %s must be int. Found %s with type %s"
                    % (name, str(value), str(type(value))))
+ as a function of other columns; + an IdentityGenerator object if the column is always + generated using identity generator See online documentation for details on Generated Columns. - :type generatedAlwaysAs: str + :type generatedAlwaysAs: str or delta.tables.IdentityGenerator + :param generatedByDefaultAs: an IdentityGenerator object to generate identity values + if the user does not provide values for the column + See online documentation for details on Generated Columns. + :type generatedByDefaultAs: delta.tables.IdentityGenerator :param comment: the column comment :type comment: str @@ -1203,11 +1228,31 @@ def addColumn( if type(nullable) is not bool: self._raise_type_error("Column nullable must be bool.", [nullable]) _col_jbuilder = _col_jbuilder.nullable(nullable) + + if generatedAlwaysAs is not None and generatedByDefaultAs is not None: + raise ValueError( + "generatedByDefaultAs and generatedAlwaysAs cannot both be set.", + [generatedByDefaultAs, generatedAlwaysAs]) if generatedAlwaysAs is not None: - if type(generatedAlwaysAs) is not str: - self._raise_type_error("Column generation expression must be str.", - [generatedAlwaysAs]) - _col_jbuilder = _col_jbuilder.generatedAlwaysAs(generatedAlwaysAs) + if type(generatedAlwaysAs) is str: + _col_jbuilder = _col_jbuilder.generatedAlwaysAs(generatedAlwaysAs) + elif isinstance(generatedAlwaysAs, IdentityGenerator): + self._check_identity_column_spec(generatedAlwaysAs) + _col_jbuilder = _col_jbuilder.generatedAlwaysAsIdentity( + generatedAlwaysAs.start, generatedAlwaysAs.step) + else: + self._raise_type_error( + "Generated always as expression must be str or IdentityGenerator.", + [generatedAlwaysAs]) + elif generatedByDefaultAs is not None: + if not isinstance(generatedByDefaultAs, IdentityGenerator): + self._raise_type_error( + "Generated by default expression must be IdentityGenerator.", + [generatedByDefaultAs]) + self._check_identity_column_spec(generatedByDefaultAs) + _col_jbuilder = 
def __verify_identity_column(self, tableName: str, deltaTable: DeltaTable) -> None:
    """
    Insert three rows into ``tableName`` and check the resulting table content.

    The first two inserts supply only ``val`` (0 and 1); the third supplies
    ``id3 = 8`` together with ``val = 2``. The content of ``deltaTable`` is
    then compared with the expected ``(id1, id2, id3, val)`` rows.

    :param tableName: name of the table the INSERT statements target
    :param deltaTable: DeltaTable whose DataFrame is checked afterwards
    """
    statements = [
        "INSERT INTO {table} (val) VALUES ({i})".format(table=tableName, i=value)
        for value in range(2)
    ]
    statements.append(
        "INSERT INTO {table} (id3, val) VALUES (8, 2)".format(table=tableName))
    # Statements are executed in the order they were built above.
    for statement in statements:
        self.spark.sql(statement)

    expected_rows = [(1, 2, 2, 0), (2, 3, 4, 1), (3, 4, 8, 2)]
    self.__checkAnswer(
        deltaTable.toDF(),
        expectedAnswer=expected_rows,
        schema=["id1", "id2", "id3", "val"])
def test_create_table_with_identity_column(self) -> None:
    """
    Create a table with identity columns and verify its schema and values.

    Runs once with ``DeltaTable.create`` and once with
    ``DeltaTable.createIfNotExists``. The
    ``spark.databricks.delta.identityColumn.enabled`` conf is set to "true"
    for each run and unset again in ``finally``.
    """
    conf_key = "spark.databricks.delta.identityColumn.enabled"
    for use_if_not_exists in (False, True):
        table_name = "testTable{}".format(use_if_not_exists)
        with self.table(table_name):
            try:
                self.spark.conf.set(conf_key, "true")
                if use_if_not_exists:
                    builder = DeltaTable.createIfNotExists(self.spark)
                else:
                    builder = DeltaTable.create(self.spark)
                builder = builder.tableName(table_name)

                # Two always-generated identity columns, one generated by
                # default, and one plain non-nullable value column.
                builder = builder.addColumn(
                    "id1", LongType(), generatedAlwaysAs=IdentityGenerator())
                builder = builder.addColumn(
                    "id2", "BIGINT", generatedAlwaysAs=IdentityGenerator(start=2))
                builder = builder.addColumn(
                    "id3", "bigint",
                    generatedByDefaultAs=IdentityGenerator(start=2, step=2))
                builder = builder.addColumn("val", "bigint", nullable=False)

                delta_table = builder.execute()
                self.__verify_table_schema(
                    table_name,
                    delta_table.toDF().schema,
                    ["id1", "id2", "id3", "val"],
                    [LongType(), LongType(), LongType(), LongType()],
                    nullables={"id1", "id2", "id3"})
                self.__verify_identity_column(table_name, delta_table)
            finally:
                self.spark.conf.unset(conf_key)
None: with self.assertRaises(TypeError): builder.addColumn("a", "int", generatedAlwaysAs=1) # type: ignore[arg-type] + # bad generatedAlwaysAs - identity column data type must be Long + with self.assertRaises(UnsupportedOperationException): + builder.addColumn( + "a", + "int", + generatedAlwaysAs=IdentityGenerator() + ) # type: ignore[arg-type] + + # bad generatedAlwaysAs - step can't be 0 + with self.assertRaises(ValueError): + builder.addColumn( + "a", + "bigint", + generatedAlwaysAs=IdentityGenerator(step=0) + ) # type: ignore[arg-type] + + # bad generatedByDefaultAs - can't be set with generatedAlwaysAs + with self.assertRaises(ValueError): + builder.addColumn( + "a", + "bigint", + generatedAlwaysAs="", + generatedByDefaultAs=IdentityGenerator() + ) # type: ignore[arg-type] + + # bad generatedByDefaultAs - argument type must be IdentityGenerator + with self.assertRaises(TypeError): + builder.addColumn( + "a", + "bigint", + generatedByDefaultAs="" + ) # type: ignore[arg-type] + + # bad generatedByDefaultAs - identity column data type must be Long + with self.assertRaises(UnsupportedOperationException): + builder.addColumn( + "a", + "int", + generatedByDefaultAs=IdentityGenerator() + ) # type: ignore[arg-type] + + # bad generatedByDefaultAs - step can't be 0 + with self.assertRaises(ValueError): + builder.addColumn( + "a", + "bigint", + generatedByDefaultAs=IdentityGenerator(step=0) + ) # type: ignore[arg-type] + # bad nullable with self.assertRaises(TypeError): builder.addColumn("a", "int", nullable=1) # type: ignore[arg-type]