[SPARK-50291][PYTHON] Standardize verifySchema parameter of createDataFrame in Spark Classic #48677

Closed
wants to merge 15 commits
37 changes: 27 additions & 10 deletions python/pyspark/sql/pandas/conversion.py
@@ -27,6 +27,7 @@
)
from warnings import warn

from pyspark._globals import _NoValue, _NoValueType
from pyspark.errors.exceptions.captured import unwrap_spark_exception
from pyspark.loose_version import LooseVersion
from pyspark.util import _load_from_socket
@@ -352,7 +353,7 @@ def createDataFrame(
self,
data: "PandasDataFrameLike",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> "DataFrame":
...

@@ -361,7 +362,7 @@ def createDataFrame(
self,
data: "pa.Table",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> "DataFrame":
...

@@ -370,7 +371,7 @@ def createDataFrame(  # type: ignore[misc]
data: Union["PandasDataFrameLike", "pa.Table"],
schema: Optional[Union[StructType, List[str]]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> "DataFrame":
from pyspark.sql import SparkSession

@@ -392,7 +393,7 @@ def createDataFrame(  # type: ignore[misc]
if schema is None:
schema = data.schema.names

return self._create_from_arrow_table(data, schema, timezone)
return self._create_from_arrow_table(data, schema, timezone, verifySchema)

# `data` is a PandasDataFrameLike object
from pyspark.sql.pandas.utils import require_minimum_pandas_version
@@ -405,7 +406,7 @@ def createDataFrame(  # type: ignore[misc]

if self._jconf.arrowPySparkEnabled() and len(data) > 0:
try:
return self._create_from_pandas_with_arrow(data, schema, timezone)
return self._create_from_pandas_with_arrow(data, schema, timezone, verifySchema)
except Exception as e:
if self._jconf.arrowPySparkFallbackEnabled():
msg = (
@@ -624,7 +625,11 @@ def _get_numpy_record_dtype(self, rec: "np.recarray") -> Optional["np.dtype"]:
return np.dtype(record_type_list) if has_rec_fix else None

def _create_from_pandas_with_arrow(
self, pdf: "PandasDataFrameLike", schema: Union[StructType, List[str]], timezone: str
self,
pdf: "PandasDataFrameLike",
schema: Union[StructType, List[str]],
timezone: str,
verifySchema: Union[_NoValueType, bool],
) -> "DataFrame":
"""
Create a DataFrame from a given pandas.DataFrame by slicing it into partitions, converting
@@ -657,6 +662,10 @@ def _create_from_pandas_with_arrow(
)
import pyarrow as pa

if verifySchema is _NoValue:
# (With Arrow optimization) createDataFrame with `pandas.DataFrame`
verifySchema = self._jconf.arrowSafeTypeConversion()

infer_pandas_dict_as_map = (
str(self.conf.get("spark.sql.execution.pandas.inferPandasDictAsMap")).lower() == "true"
)
@@ -725,8 +734,7 @@ def _create_from_pandas_with_arrow(

jsparkSession = self._jsparkSession

safecheck = self._jconf.arrowSafeTypeConversion()
ser = ArrowStreamPandasSerializer(timezone, safecheck)
ser = ArrowStreamPandasSerializer(timezone, verifySchema)

@no_type_check
def reader_func(temp_filename):
@@ -745,7 +753,11 @@ def create_iter_server():
return df

def _create_from_arrow_table(
self, table: "pa.Table", schema: Union[StructType, List[str]], timezone: str
self,
table: "pa.Table",
schema: Union[StructType, List[str]],
timezone: str,
verifySchema: Union[_NoValueType, bool],
) -> "DataFrame":
"""
Create a DataFrame from a given pyarrow.Table by slicing it into partitions then
@@ -767,6 +779,10 @@ def _create_from_arrow_table(

require_minimum_pyarrow_version()

if verifySchema is _NoValue:
# createDataFrame with `pyarrow.Table`
verifySchema = False

prefer_timestamp_ntz = is_timestamp_ntz_preferred()

# Create the Spark schema from list of names passed in with Arrow types
@@ -786,7 +802,8 @@ def _create_from_arrow_table(
schema = from_arrow_schema(table.schema, prefer_timestamp_ntz=prefer_timestamp_ntz)

table = _check_arrow_table_timestamps_localize(table, schema, True, timezone).cast(
to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True)
to_arrow_schema(schema, error_on_duplicated_field_names_in_struct=True),
safe=verifySchema,
)

# Chunk the Arrow Table into RecordBatches
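The change above threads `verifySchema` into the Arrow path, where it becomes the `safe` flag of `pyarrow.Table.cast` (and the safecheck flag of `ArrowStreamPandasSerializer`). Below is a minimal standalone sketch of what that flag does to out-of-range values, using the same overflowing integers as the new test in this PR; it is illustrative only and not PySpark's internal code path.

```python
import pyarrow as pa

# int64 values that do not fit into int32, as in the test added by this PR
table = pa.table({"value": [100000000000, 200000000000]})
target = pa.schema([("value", pa.int32())])

# safe=False allows the lossy cast and wraps the values -- the verifySchema=False behavior
print(table.cast(target, safe=False).column("value"))

# safe=True rejects the lossy cast -- the verifySchema=True behavior
try:
    table.cast(target, safe=True)
except pa.ArrowInvalid as err:
    print("cast rejected:", err)
```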
35 changes: 22 additions & 13 deletions python/pyspark/sql/session.py
@@ -38,6 +38,7 @@
TYPE_CHECKING,
)

from pyspark._globals import _NoValue, _NoValueType
from pyspark.conf import SparkConf
from pyspark.util import is_remote_only
from pyspark.sql.conf import RuntimeConfig
@@ -1265,7 +1266,7 @@ def createDataFrame(
data: Iterable["RowLike"],
schema: Union[StructType, str],
*,
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1275,7 +1276,7 @@ def createDataFrame(
data: "RDD[RowLike]",
schema: Union[StructType, str],
*,
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1284,7 +1285,7 @@ def createDataFrame(
self,
data: "RDD[AtomicValue]",
schema: Union[AtomicType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1293,7 +1294,7 @@ def createDataFrame(
self,
data: Iterable["AtomicValue"],
schema: Union[AtomicType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1312,7 +1313,7 @@ def createDataFrame(
self,
data: "PandasDataFrameLike",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1321,7 +1322,7 @@ def createDataFrame(
self,
data: "pa.Table",
schema: Union[StructType, str],
verifySchema: bool = ...,
verifySchema: Union[_NoValueType, bool] = ...,
) -> DataFrame:
...

@@ -1330,7 +1331,7 @@ def createDataFrame(  # type: ignore[misc]
data: Union["RDD[Any]", Iterable[Any], "PandasDataFrameLike", "ArrayLike", "pa.Table"],
schema: Optional[Union[AtomicType, StructType, str]] = None,
samplingRatio: Optional[float] = None,
verifySchema: bool = True,
verifySchema: Union[_NoValueType, bool] = _NoValue,
) -> DataFrame:
"""
Creates a :class:`DataFrame` from an :class:`RDD`, a list, a :class:`pandas.DataFrame`,
@@ -1374,11 +1375,14 @@ def createDataFrame(  # type: ignore[misc]
if ``samplingRatio`` is ``None``. This option is effective only when the input is
:class:`RDD`.
verifySchema : bool, optional
verify data types of every row against schema. Enabled by default.
When the input is :class:`pyarrow.Table` or when the input class is
:class:`pandas.DataFrame` and `spark.sql.execution.arrow.pyspark.enabled` is enabled,
this option is not effective. It follows Arrow type coercion. This option is not
supported with Spark Connect.
verify data types of every row against schema.
If not provided, the default depends on the input:
- pyarrow.Table: verifySchema defaults to False
- pandas.DataFrame with Arrow optimization: verifySchema defaults to
`spark.sql.execution.pandas.convertToArrowArraySafely`
- pandas.DataFrame without Arrow optimization: verifySchema defaults to True
- regular Python instances: verifySchema defaults to True
Arrow optimization is enabled/disabled via `spark.sql.execution.arrow.pyspark.enabled`.

.. versionadded:: 2.1.0

@@ -1578,8 +1582,13 @@ def _create_dataframe(
data: Union["RDD[Any]", Iterable[Any]],
schema: Optional[Union[DataType, List[str]]],
samplingRatio: Optional[float],
verifySchema: bool,
verifySchema: Union[_NoValueType, bool],
) -> DataFrame:
if verifySchema is _NoValue:
# createDataFrame with regular Python instances
# or (without Arrow optimization) createDataFrame with Pandas DataFrame
verifySchema = True

if isinstance(schema, StructType):
verify_func = _make_type_verifier(schema) if verifySchema else lambda _: True

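With this change, `_create_dataframe` resolves the `_NoValue` sentinel to `True` for the non-Arrow paths, so the behavior follows the docstring above whether or not the caller passes `verifySchema`. The following is a hedged usage sketch of the new defaults from a user's point of view; it assumes a local Spark Classic session and mirrors the new test rather than reproducing PySpark internals.

```python
import pyarrow as pa
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, IntegerType

spark = SparkSession.builder.getOrCreate()

schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("value", IntegerType(), True),  # too narrow for the data below
])
table = pa.table({"id": [1, 2], "value": [100000000000, 200000000000]})

# pyarrow.Table input: verifySchema defaults to False, so the overflow is silently wrapped
spark.createDataFrame(table, schema=schema).show()

# Opting in to verification makes the lossy cast fail
try:
    spark.createDataFrame(table, schema=schema, verifySchema=True)
except Exception as err:
    print("rejected:", err)
```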
4 changes: 4 additions & 0 deletions python/pyspark/sql/tests/connect/test_parity_arrow.py
@@ -137,6 +137,10 @@ def test_toPandas_udt(self):
def test_create_dataframe_namedtuples(self):
self.check_create_dataframe_namedtuples(True)

@unittest.skip("Spark Connect does not support verifySchema.")
def test_createDataFrame_verifySchema(self):
super().test_createDataFrame_verifySchema()

if __name__ == "__main__":
from pyspark.sql.tests.connect.test_parity_arrow import * # noqa: F401
39 changes: 39 additions & 0 deletions python/pyspark/sql/tests/test_arrow.py
@@ -532,6 +532,45 @@ def test_createDataFrame_arrow_pandas(self):
df_pandas = self.spark.createDataFrame(pdf)
self.assertEqual(df_arrow.collect(), df_pandas.collect())

def test_createDataFrame_verifySchema(self):
data = {"id": [1, 2, 3], "value": [100000000000, 200000000000, 300000000000]}
# data.value should fail schema validation when verifySchema is True
schema = StructType(
[StructField("id", IntegerType(), True), StructField("value", IntegerType(), True)]
)
expected = [
Row(id=1, value=1215752192),
Row(id=2, value=-1863462912),
Row(id=3, value=-647710720),
]
# Arrow table
table = pa.table(data)
df = self.spark.createDataFrame(table, schema=schema)
self.assertEqual(df.collect(), expected)

with self.assertRaises(Exception):
self.spark.createDataFrame(table, schema=schema, verifySchema=True)

# pandas DataFrame with Arrow optimization
pdf = pd.DataFrame(data)
df = self.spark.createDataFrame(pdf, schema=schema)
# verifySchema defaults to `spark.sql.execution.pandas.convertToArrowArraySafely`,
# which is false by default
self.assertEqual(df.collect(), expected)
with self.assertRaises(Exception):
with self.sql_conf({"spark.sql.execution.pandas.convertToArrowArraySafely": True}):
df = self.spark.createDataFrame(pdf, schema=schema)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=True)

# pandas DataFrame without Arrow optimization
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
pdf = pd.DataFrame(data)
with self.assertRaises(Exception):
df = self.spark.createDataFrame(pdf, schema=schema) # verifySchema defaults to True
df = self.spark.createDataFrame(pdf, schema=schema, verifySchema=False)
self.assertEqual(df.collect(), expected)

def _createDataFrame_toggle(self, data, schema=None):
with self.sql_conf({"spark.sql.execution.arrow.pyspark.enabled": False}):
df_no_arrow = self.spark.createDataFrame(data, schema=schema)
13 changes: 7 additions & 6 deletions python/pyspark/sql/tests/typing/test_session.yml
@@ -17,6 +17,7 @@

- case: createDataFrameStructsValid
main: |
from pyspark._globals import _NoValueType
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType

@@ -78,14 +79,14 @@
main:18: note: Possible overload variants:
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[list[str], tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[list[str], tuple[str, ...]] = ..., samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[StructType, str], *, verifySchema: bool = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[StructType, str], *, verifySchema: bool = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: Iterable[RowLike], schema: Union[StructType, str], *, verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def [RowLike in (list[Any], tuple[Any, ...], Row)] createDataFrame(self, data: RDD[RowLike], schema: Union[StructType, str], *, verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: RDD[AtomicValue], schema: Union[AtomicType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def [AtomicValue in (datetime, date, Decimal, bool, str, int, float)] createDataFrame(self, data: Iterable[AtomicValue], schema: Union[AtomicType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: DataFrame, samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: Any, samplingRatio: Optional[float] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: bool = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: DataFrame, schema: Union[StructType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame
main:18: note: def createDataFrame(self, data: Any, schema: Union[StructType, str], verifySchema: Union[_NoValueType, bool] = ...) -> DataFrame

- case: createDataFrameFromEmptyRdd
main: |