[SPARK-50291][PYTHON] Standardize verifySchema parameter of createDataFrame in Spark Classic

### What changes were proposed in this pull request?
This PR targets Spark Classic only. Spark Connect will be handled in a follow-up PR.

The `verifySchema` parameter of createDataFrame decides whether to verify the data types of every row against the schema.

Currently it only takes effect for createDataFrame with
- regular Python instances

This PR proposes to also make it take effect for createDataFrame with
- `pyarrow.Table`
- `pandas.DataFrame` with Arrow optimization
- `pandas.DataFrame` without Arrow optimization

By default, the `verifySchema` parameter is `pyspark._NoValue`. If it is not provided, createDataFrame resolves it as follows (a sketch follows this list):
- `pyarrow.Table`: **verifySchema = False**
- `pandas.DataFrame` with Arrow optimization: **verifySchema = spark.sql.execution.pandas.convertToArrowArraySafely**
- `pandas.DataFrame` without Arrow optimization: **verifySchema = True**
- regular Python instances: **verifySchema = True** (existing behavior)
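
The resolution of the default can be summarized with a short sketch. The helper below is hypothetical and only restates the rules above; the real logic lives in `SparkSession.createDataFrame` and its `_create_from_*` helpers:

```py
from pyspark._globals import _NoValue

def resolve_verify_schema(verifySchema, input_kind, arrow_safe_conf):
    # Hypothetical helper, not a PySpark API: maps "argument not provided"
    # onto the per-input-type defaults listed above.
    if verifySchema is not _NoValue:
        return verifySchema          # an explicit True/False always wins
    if input_kind == "arrow_table":
        return False                 # pyarrow.Table
    if input_kind == "pandas_with_arrow":
        return arrow_safe_conf       # spark.sql.execution.pandas.convertToArrowArraySafely
    return True                      # pandas without Arrow, or regular Python instances
```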

### Why are the changes needed?
The change makes schema validation consistent across all formats, improving data integrity and helping prevent errors.
It also enhances flexibility by allowing users to choose schema verification regardless of the input type.

Part of [SPARK-50146](https://issues.apache.org/jira/browse/SPARK-50146).

### Does this PR introduce _any_ user-facing change?
Yes.

Setup:
```py
>>> import pyarrow as pa
>>> import pandas as pd
>>> from pyspark.sql.types import *
>>>
>>> data = {
...     "id": [1, 2, 3],
...     "value": [100000000000, 200000000000, 300000000000]
... }
>>> schema = StructType([StructField("id", IntegerType(), True), StructField("value", IntegerType(), True)])
```
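
The `value` entries deliberately overflow `IntegerType` (32-bit signed, maximum 2147483647); the surprising numbers in the outputs below are simply the low 32 bits of the original values, reinterpreted as signed int32:

```py
>>> 2**31 - 1                               # IntegerType upper bound
2147483647
>>> 100000000000 & 0xFFFFFFFF               # low 32 bits of the first value
1215752192
>>> (200000000000 & 0xFFFFFFFF) - 2**32     # the second value wraps into the negative range
-1863462912
```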

Usage - createDataFrame with `pyarrow.Table`
```py
>>> table = pa.table(data)
>>> spark.createDataFrame(table, schema=schema).show()  # verifySchema defaults to False
+---+-----------+
| id|      value|
+---+-----------+
|  1| 1215752192|
|  2|-1863462912|
|  3| -647710720|
+---+-----------+

>>> spark.createDataFrame(table, schema=schema, verifySchema=True).show()
...
pyarrow.lib.ArrowInvalid: Integer value 100000000000 not in range: -2147483648 to 2147483647
```

Usage - createDataFrame with `pandas.DataFrame` without Arrow optimization

```py
>>> pdf = pd.DataFrame(data)
>>> spark.createDataFrame(pdf, schema=schema).show()  # verifySchema defaults to True
...
pyspark.errors.exceptions.base.PySparkValueError: [VALUE_OUT_OF_BOUNDS] Value for `obj` must be between -2147483648 and 2147483647 (inclusive), got 100000000000
>>> spark.createDataFrame(table, schema=schema, verifySchema=False).show()
+---+-----------+
| id|      value|
+---+-----------+
|  1| 1215752192|
|  2|-1863462912|
|  3| -647710720|
+---+-----------+
```

Usage - createDataFrame with `pandas.DataFrame` with Arrow optimization

```py
>>> pdf = pd.DataFrame(data)
>>> spark.conf.set("spark.sql.execution.arrow.pyspark.enabled", True)
>>> spark.conf.get("spark.sql.execution.pandas.convertToArrowArraySafely")
'false'
>>> spark.createDataFrame(pdf, schema=schema).show()  # verifySchema defaults to "spark.sql.execution.pandas.convertToArrowArraySafely"
+---+-----------+
| id|      value|
+---+-----------+
|  1| 1215752192|
|  2|-1863462912|
|  3| -647710720|
+---+-----------+

>>> spark.conf.set("spark.sql.execution.pandas.convertToArrowArraySafely", True)
>>> spark.createDataFrame(pdf, schema=schema).show()
...
pyspark.errors.exceptions.base.PySparkValueError: [VALUE_OUT_OF_BOUNDS] Value for `obj` must be between -2147483648 and 2147483647 (inclusive), got 100000000000

>>> spark.createDataFrame(table, schema=schema, verifySchema=True).show()
...
pyarrow.lib.ArrowInvalid: Integer value 100000000000 not in range: -2147483648 to 2147483647
```

### How was this patch tested?
Unit tests, including the new `test_createDataFrame_verifySchema` case added to `python/pyspark/sql/tests/test_arrow.py` (skipped in the Spark Connect parity suite).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48677 from xinrong-meng/arrowSafe.

Authored-by: Xinrong Meng <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
xinrong-meng authored and HyukjinKwon committed Nov 14, 2024
1 parent 0b1b676 commit aea9e87
Showing 5 changed files with 99 additions and 29 deletions.
37 changes: 27 additions & 10 deletions python/pyspark/sql/pandas/conversion.py
@@ -27,6 +27,7 @@
)
from warnings import warn

from pyspark._globals import _NoValue, _NoValueType
from pyspark.errors.exceptions.captured import unwrap_spark_exception
from pyspark.loose_version import LooseVersion
from pyspark.util import _load_from_socket
@@ -352,7 +353,7 @@ def createDataFrame(
self,
data: "PandasDataFrameLike",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> "DataFrame":
...

@@ -361,7 +362,7 @@ def createDataFrame(
self,
data: "pa.Table",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> "DataFrame":
...

@@ -370,7 +371,7 @@ def createDataFrame( # type: ignore[misc]
data: Union["PandasDataFrameLike", "pa.Table"],
schema: Optional[Union[StructType, List[str]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> "DataFrame":
from pyspark.sql import SparkSession

@@ -392,7 +393,7 @@ def createDataFrame( # type: ignore[misc]
if schema is None:
schema = data.schema.names

return self._create_from_arrow_table(data, schema, timezone)
return self._create_from_arrow_table(data, schema, timezone, verifySchema)

# `data` is a PandasDataFrameLike object
from pyspark.sql.pandas.utils import require_minimum_pandas_version
@@ -405,7 +406,7 @@ def createDataFrame( # type: ignore[misc]

if self._jconf.arrowPySparkEnabled() and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
return self._create_from_pandas_with_arrow(data, schema, timezone, verifySchema)
except Exception as e:
if self._jconf.arrowPySparkFallbackEnabled():
msg = (
@@ -624,7 +625,11 @@ def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
return np.dtype(record_type_list) if has_rec_fix else None

def _create_from_pandas_with_arrow(
self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
self,
pdf: "PandasDataFrameLike",
schema: Union[StructType, List[str]],
timezone: str,
verifySchema: Union[_NoValueType, bool],
) -> "DataFrame":
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
@@ -657,6 +662,10 @@ def _create_from_pandas_with_arrow(
)
import pyarrow as pa

if verifySchema is _NoValue:
# (With Arrow optimization) createDataFrame with `pandas.DataFrame`
verifySchema = self._jconf.arrowSafeTypeConversion()

infer_pandas_dict_as_map = (
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower() == "true"
)
@@ -725,8 +734,7 @@ def _create_from_pandas_with_arrow(

jsparkSession = self._jsparkSession

safecheck = self._jconf.arrowSafeTypeConversion()
ser = ArrowStreamPandasSerializer(timezone, safecheck)
ser = ArrowStreamPandasSerializer(timezone, verifySchema)

@no_type_check
def reader_func(temp_filename):
@@ -745,7 +753,11 @@ def create_iter_server():
return df

def _create_from_arrow_table(
self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str
self,
table: "pa.Table",
schema: Union[StructType, List[str]],
timezone: str,
verifySchema: Union[_NoValueType, bool],
) -> "DataFrame":
"""
Create a DataFrame from a given pyarrow.Table by slicing it into partitions then
@@ -767,6 +779,10 @@ def _create_from_arrow_table(

require_minimum_pyarrow_version()

if verifySchema is _NoValue:
# createDataFrame with `pyarrow.Table`
verifySchema = False

prefer_timestamp_ntz = is_timestamp_ntz_preferred()

# Create the Spark schema from list of names passed in with Arrow types
@@ -786,7 +802,8 @@ def _create_from_arrow_table(
schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)

table = _check_arrow_table_timestamps_localize(table, schema, True, timezone).cast(
to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)
to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True),
safe=verifySchema,
)

# Chunk the Arrow Table into RecordBatches
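
For the `pyarrow.Table` path, the verification is delegated to `pyarrow.Table.cast(..., safe=verifySchema)` above, while the pandas-with-Arrow path hands the flag to `ArrowStreamPandasSerializer` as its safe-conversion check. A standalone sketch of the pyarrow behaviour the table path relies on, matching the outputs quoted in the PR description:

```py
import pyarrow as pa

table = pa.table({"value": [100000000000]})
target = pa.schema([("value", pa.int32())])

table.cast(target, safe=False)   # unsafe cast: silently truncates to the low 32 bits
table.cast(target, safe=True)    # raises pyarrow.lib.ArrowInvalid: value not in int32 range
```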
35 changes: 22 additions & 13 deletions python/pyspark/sql/session.py
@@ -38,6 +38,7 @@
TYPE_CHECKING,
)

from pyspark._globals import _NoValue, _NoValueType
from pyspark.conf import SparkConf
from pyspark.util import is_remote_only
from pyspark.sql.conf import RuntimeConfig
@@ -1265,7 +1266,7 @@ def createDataFrame(
data: Iterable["RowLike"],
schema: Union[StructType, str],
*,
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1275,7 +1276,7 @@ def createDataFrame(
data: "RDD[RowLike]",
schema: Union[StructType, str],
*,
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1284,7 +1285,7 @@ def createDataFrame(
self,
data: "RDD[AtomicValue]",
schema: Union[AtomicType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1293,7 +1294,7 @@ def createDataFrame(
self,
data: Iterable["AtomicValue"],
schema: Union[AtomicType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1312,7 +1313,7 @@ def createDataFrame(
self,
data: "PandasDataFrameLike",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1321,7 +1322,7 @@ def createDataFrame(
self,
data: "pa.Table",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1330,7 +1331,7 @@ def createDataFrame( # type: ignore[misc]
data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "ArrayLike", "pa.Table"],
schema: Optional[Union[AtomicType, StructType, str]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> DataFrame:
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`,
@@ -1374,11 +1375,14 @@ def createDataFrame( # type: ignore[misc]
if ``samplingRatio`` is ``None``. This option is effective only when the input is
:class:`RDD`.
verifySchema : bool, optional
verify data types of every row against schema. Enabled by default.
When the input is :class:`pyarrow.Table` or when the input class is
:class:`pandas.DataFrame` and `spark.sql.execution.arrow.pyspark.enabled` is enabled,
this option is not effective. It follows Arrow type coercion. This option is not
supported with Spark Connect.
verify data types of every row against schema.
If not provided, createDataFrame with
- pyarrow.Table, verifySchema=False
- pandas.DataFrame with Arrow optimization, verifySchema defaults to
`spark.sql.execution.pandas.convertToArrowArraySafely`
- pandas.DataFrame without Arrow optimization, verifySchema=True
- regular Python instances, verifySchema=True
Arrow optimization is enabled/disabled via `spark.sql.execution.arrow.pyspark.enabled`.
.. versionadded:: 2.1.0
@@ -1578,8 +1582,13 @@ def _create_dataframe(
data: Union["RDD[Any]", Iterable[Any]],
schema: Optional[Union[DataType, List[str]]],
samplingRatio: Optional[float],
verifySchema: bool,
verifySchema: Union[_NoValueType, bool],
) -> DataFrame:
if verifySchema is _NoValue:
# createDataFrame with regular Python instances
# or (without Arrow optimization) createDataFrame with Pandas DataFrame
verifySchema = True

if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True

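
The `_NoValue` sentinel from `pyspark._globals` is used instead of a plain boolean default so that "argument not provided" can be distinguished from an explicit `verifySchema=False`. A minimal sketch of the pattern, simplified from the actual Spark code:

```py
from typing import Union

from pyspark._globals import _NoValue, _NoValueType

def create(verifySchema: Union[_NoValueType, bool] = _NoValue) -> bool:
    # The sentinel marks "caller did not pass the argument", letting each code
    # path pick its own context-dependent default later on.
    if verifySchema is _NoValue:
        verifySchema = True   # e.g. the regular-Python-instances path defaults to True
    return verifySchema

create()       # -> True, default resolved inside
create(False)  # -> False, explicit choice preserved
```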
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -137,6 +137,10 @@ def test_toPandas_udt(self):
def test_create_dataframe_namedtuples(self):
self.check_create_dataframe_namedtuples(True)

@unittest.skip("Spark Connect does not support verifySchema.")
def test_createDataFrame_verifySchema(self):
super().test_createDataFrame_verifySchema()


if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401
39 changes: 39 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -532,6 +532,45 @@ def test_createDataFrame_arrow_pandas(self):
df_pandas = self.spark.createDataFrame(pdf)
self.assertEqual(df_arrow.collect(), df_pandas.collect())

def test_createDataFrame_verifySchema(self):
data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]}
# data.value should fail schema validation when verifySchema is True
schema = StructType(
[StructField("id", IntegerType(), True), StructField("value", IntegerType(), True)]
)
expected = [
Row(id=1, value=1215752192),
Row(id=2, value=-1863462912),
Row(id=3, value=-647710720),
]
# Arrow table
table = pa.table(data)
df = self.spark.createDataFrame(table, schema=schema)
self.assertEqual(df.collect(), expected)

with self.assertRaises(Exception):
self.spark.createDataFrame(table, schema=schema, verifySchema=True)

# pandas DataFrame with Arrow optimization
pdf = pd.DataFrame(data)
df = self.spark.createDataFrame(pdf, schema=schema)
# verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
# which is false by default
self.assertEqual(df.collect(), expected)
with self.assertRaises(Exception):
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
df = self.spark.createDataFrame(pdf, schema=schema)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)

# pandas DataFrame without Arrow optimization
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
pdf = pd.DataFrame(data)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema) # verifySchema defaults to True
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
self.assertEqual(df.collect(), expected)

def _createDataFrame_toggle(self, data, schema=None):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df_no_arrow = self.spark.createDataFrame(data, schema=schema)
13 changes: 7 additions & 6 deletions python/pyspark/sql/tests/typing/test_session.yml
@@ -17,6 +17,7 @@

- case: createDataFrameStructsValid
main: |
from pyspark._globals import _NoValueType
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
@@ -78,14 +79,14 @@
main:18: note: Possible overload variants:
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[list[str], tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[list[str], tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[StructType, str], *, verifySchema: bool = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[StructType, str], *, verifySchema: bool = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[StructType, str], *, verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[StructType, str], *, verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: DataFrame, samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: Any, samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
- case: createDataFrameFromEmptyRdd
main: |