Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-44640][PYTHON][FOLLOW-UP][3.5] Update UDTF error messages to include method name #42840

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/pyspark/errors/error_classes.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@
},
"INVALID_ARROW_UDTF_RETURN_TYPE" : {
"message" : [
"The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the function returned a value of type <type_name> with value: <value>."
"The return type of the arrow-optimized Python UDTF should be of type 'pandas.DataFrame', but the '<func>' method returned a value of type <type_name> with value: <value>."
]
},
"INVALID_BROADCAST_OPERATION": {
Expand Down Expand Up @@ -730,17 +730,17 @@
},
"UDTF_INVALID_OUTPUT_ROW_TYPE" : {
"message" : [
"The type of an individual output row in the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
"The type of an individual output row in the '<func>' method of the UDTF is invalid. Each row should be a tuple, list, or dict, but got '<type>'. Please make sure that the output rows are of the correct type."
]
},
"UDTF_RETURN_NOT_ITERABLE" : {
"message" : [
"The return value of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
"The return value of the '<func>' method of the UDTF is invalid. It should be an iterable (e.g., generator or list), but got '<type>'. Please make sure that the UDTF returns one of these types."
]
},
"UDTF_RETURN_SCHEMA_MISMATCH" : {
"message" : [
"The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the function have the same number of columns as specified in the output schema."
"The number of columns in the result does not match the specified schema. Expected column count: <expected>, Actual column count: <actual>. Please make sure the values returned by the '<func>' method have the same number of columns as specified in the output schema."
]
},
"UDTF_RETURN_TYPE_MISMATCH" : {
Expand Down
21 changes: 21 additions & 0 deletions python/pyspark/sql/tests/test_udtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,27 @@ def eval(self, a):
with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
TestUDTF(lit(1)).collect()

def test_udtf_with_zero_arg_and_invalid_return_value(self):
    """A zero-argument UDTF whose ``eval`` returns a non-iterable must fail.

    ``eval`` returns the int ``1``; collecting the UDTF is expected to raise
    a ``PythonException`` whose message matches ``UDTF_RETURN_NOT_ITERABLE``.
    """

    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self):
            # An int is not iterable, so this return value is invalid.
            return 1

    with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
        TestUDTF().collect()

def test_udtf_with_invalid_return_value_in_terminate(self):
    """A UDTF whose ``terminate`` returns a non-iterable must fail.

    ``eval`` produces nothing, while ``terminate`` returns the int ``1``;
    collecting the UDTF is expected to raise a ``PythonException`` whose
    message matches ``UDTF_RETURN_NOT_ITERABLE``.
    """

    @udtf(returnType="x: int")
    class TestUDTF:
        def eval(self, a):
            # Intentionally empty: only ``terminate`` is under test here.
            ...

        def terminate(self):
            # An int is not iterable, so this return value is invalid.
            return 1

    with self.assertRaisesRegex(PythonException, "UDTF_RETURN_NOT_ITERABLE"):
        TestUDTF(lit(1)).collect()

def test_udtf_eval_with_no_return(self):
@udtf(returnType="a: int")
class TestUDTF:
Expand Down
37 changes: 27 additions & 10 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,6 +655,7 @@ def verify_result(result):
message_parameters={
"type_name": type(result).__name__,
"value": str(result),
"func": f.__name__,
},
)

Expand All @@ -669,6 +670,7 @@ def verify_result(result):
message_parameters={
"expected": str(return_type_size),
"actual": str(len(result.columns)),
"func": f.__name__,
},
)

Expand All @@ -688,22 +690,30 @@ def func(*args: Any) -> Any:
message_parameters={"method_name": f.__name__, "error": str(e)},
)

def check_return_value(res):
    """Reject a non-iterable UDTF result before it is turned into a DataFrame.

    ``None`` and any ``Iterable`` are accepted and the function returns
    ``None``. Anything else raises ``PySparkRuntimeError`` with error class
    ``UDTF_RETURN_NOT_ITERABLE``, reporting the offending type name and the
    name of the wrapped method ``f``.
    """
    # Accept the valid cases up front so the failure path stays unnested.
    if res is None or isinstance(res, Iterable):
        return

    offending_type = type(res).__name__
    raise PySparkRuntimeError(
        error_class="UDTF_RETURN_NOT_ITERABLE",
        message_parameters={
            "type": offending_type,
            "func": f.__name__,
        },
    )

def evaluate(*args: pd.Series):
    """Invoke the UDTF over the input columns and yield one frame per call.

    With no input columns the wrapped function is invoked exactly once;
    otherwise it is invoked once per input row. Every return value goes
    through ``check_return_value`` before being converted to a
    ``pandas.DataFrame``, so a non-iterable result is reported as
    ``UDTF_RETURN_NOT_ITERABLE`` (including the method name) rather than
    failing inside the DataFrame constructor.

    Yields
    ------
    tuple
        ``(verify_result(frame), arrow_return_type)`` for each invocation.
    """
    if len(args) == 0:
        # Invoke the function a single time and validate its result before
        # building the DataFrame; never call it twice for one evaluation.
        res = func()
        check_return_value(res)
        yield verify_result(pd.DataFrame(res)), arrow_return_type
    else:
        # Create tuples from the input pandas Series, each tuple
        # represents a row across all Series.
        row_tuples = zip(*args)
        for row in row_tuples:
            res = func(*row)
            # Single shared check keeps the error parameters consistent
            # between the zero-argument and per-row paths.
            check_return_value(res)
            yield verify_result(pd.DataFrame(res)), arrow_return_type

return evaluate
Expand Down Expand Up @@ -742,13 +752,17 @@ def verify_and_convert_result(result):
message_parameters={
"expected": str(return_type_size),
"actual": str(len(result)),
"func": f.__name__,
},
)

if not (isinstance(result, (list, dict, tuple)) or hasattr(result, "__dict__")):
raise PySparkRuntimeError(
error_class="UDTF_INVALID_OUTPUT_ROW_TYPE",
message_parameters={"type": type(result).__name__},
message_parameters={
"type": type(result).__name__,
"func": f.__name__,
},
)

return toInternal(result)
Expand All @@ -772,7 +786,10 @@ def evaluate(*a) -> tuple:
if not isinstance(res, Iterable):
raise PySparkRuntimeError(
error_class="UDTF_RETURN_NOT_ITERABLE",
message_parameters={"type": type(res).__name__},
message_parameters={
"type": type(res).__name__,
"func": f.__name__,
},
)

# If the function returns a result, we map it to the internal representation and
Expand Down