[SPARK-49899][PYTHON][SS] Support deleteIfExists for TransformWithStateInPandas #48373

Closed · 12 commits
8 changes: 6 additions & 2 deletions python/pyspark/sql/pandas/group_ops.py
@@ -491,7 +491,7 @@ def transformWithStateUDF(
statefulProcessorApiClient: StatefulProcessorApiClient,
key: Any,
inputRows: Iterator["PandasDataFrameLike"],
) -> Iterator["PandasDataFrameLike"]:
) -> Iterator[Iterator["PandasDataFrameLike"]]:
handle = StatefulProcessorHandle(statefulProcessorApiClient)

if statefulProcessorApiClient.handle_state == StatefulProcessorHandleState.CREATED:
@@ -503,7 +503,11 @@ def transformWithStateUDF(
statefulProcessorApiClient.set_implicit_key(key)
result = statefulProcessor.handleInputRows(key, inputRows)

return result
try:
yield result
finally:
statefulProcessor.close()
Contributor:

Shall we set the handle state to CLOSED here?

Contributor Author:

I just realized an issue: this is actually called after processing each grouping key, not after finishing all keys for a microbatch. I'll need to revisit this to see if there's a good way to handle it (right now I can't think of a good way to detect whether the current key is the last one to process). If it's not a quick fix, we can probably exclude it for now and have a follow-up PR fixing it. cc @HeartSaVioR

Contributor Author:

Maybe we could try injecting a dummy row at the end of the iterator in writeNextInputToArrowStream to indicate that all the keys have been processed, but I'll need to do some experiments first.

Contributor:

I don't think the current interface gives you that information; we'll probably need another control message to send the signal from the JVM to Python (the UDF). I agree this may take time, but we should probably mark it as a blocker so that we address it before the release.
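
For illustration, a minimal sketch of the sentinel idea discussed above, assuming a hypothetical end-of-data marker appended by the JVM side after the last grouping key; END_OF_MICROBATCH and process_keys are made-up names for this sketch, not part of this PR:

# Hypothetical sketch: the JVM appends a sentinel after the last key so the
# Python worker knows the microbatch is finished and close() is safe.
END_OF_MICROBATCH = object()  # made-up marker, for illustration only

def process_keys(keys_iter, statefulProcessor):
    for item in keys_iter:
        if item is END_OF_MICROBATCH:
            # All grouping keys for this microbatch have been processed.
            statefulProcessor.close()
            return
        key, rows = item
        yield statefulProcessor.handleInputRows(key, rows)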

statefulProcessorApiClient.remove_implicit_key()

if isinstance(outputStructType, str):
outputStructType = cast(StructType, _parse_datatype_string(outputStructType))
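As an aside on the try/yield/finally change above: because the UDF is now a generator, the finally clause runs when the surrounding iterator is exhausted, closed, or aborted by an exception, which is what guarantees close() gets invoked; as the thread notes, though, this currently fires once per grouping key rather than once per microbatch. A self-contained sketch of the mechanism:

def udf_like():
    try:
        yield "result"
    finally:
        # Runs on exhaustion, on close(), or if the consumer raises.
        print("close() called")

gen = udf_like()
next(gen)    # yields "result"
gen.close()  # prints "close() called"
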
2 changes: 1 addition & 1 deletion python/pyspark/sql/pandas/serializers.py
@@ -1188,5 +1188,5 @@ def dump_stream(self, iterator, stream):
Read through an iterator of (iterator of pandas DataFrame), serialize them to Arrow
RecordBatches, and write batches to stream.
"""
result = [(b, t) for x in iterator for y, t in x for b in y]
result = [(b, t) for x in iterator for y, t in x for a in y for b in a]
super().dump_stream(result, stream)
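
Why the extra for-clause: transformWithStateUDF now returns Iterator[Iterator["PandasDataFrameLike"]] instead of Iterator["PandasDataFrameLike"], so dump_stream has to flatten one more level before serializing. A toy sketch with strings standing in for DataFrames:

# x walks the outer input, (y, t) pairs data with its type, y is now an
# iterator of iterators, and a/b unroll the two inner levels.
iterator = [[(iter([iter(["df1", "df2"]), iter(["df3"])]), "t")]]
result = [(b, t) for x in iterator for y, t in x for a in y for b in a]
assert result == [("df1", "t"), ("df2", "t"), ("df3", "t")]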
109 changes: 55 additions & 54 deletions python/pyspark/sql/streaming/StateMessage_pb2.py

Large diffs are not rendered by default.

6 changes: 5 additions & 1 deletion python/pyspark/sql/streaming/StateMessage_pb2.pyi
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
@@ -76,21 +77,24 @@ class StateResponse(_message.Message):
) -> None: ...

class StatefulProcessorCall(_message.Message):
__slots__ = ("setHandleState", "getValueState", "getListState", "getMapState")
__slots__ = ("setHandleState", "getValueState", "getListState", "getMapState", "deleteIfExists")
SETHANDLESTATE_FIELD_NUMBER: _ClassVar[int]
GETVALUESTATE_FIELD_NUMBER: _ClassVar[int]
GETLISTSTATE_FIELD_NUMBER: _ClassVar[int]
GETMAPSTATE_FIELD_NUMBER: _ClassVar[int]
DELETEIFEXISTS_FIELD_NUMBER: _ClassVar[int]
setHandleState: SetHandleState
getValueState: StateCallCommand
getListState: StateCallCommand
getMapState: StateCallCommand
deleteIfExists: StateCallCommand
def __init__(
self,
setHandleState: _Optional[_Union[SetHandleState, _Mapping]] = ...,
getValueState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
getListState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
getMapState: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
deleteIfExists: _Optional[_Union[StateCallCommand, _Mapping]] = ...,
) -> None: ...

class StateVariableRequest(_message.Message):
6 changes: 6 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor.py
@@ -292,6 +292,12 @@ def getMapState(
state_name,
)

def deleteIfExists(self, state_name: str) -> None:
"""
Delete and purge the state variable if it was defined previously.
"""
self.stateful_processor_api_client.delete_if_exists(state_name)
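
A minimal usage sketch, assuming a processor that purges a leftover variable during init; the class name, state names, and pass-through handleInputRows below are illustrative only:

from pyspark.sql.streaming.stateful_processor import StatefulProcessor, StatefulProcessorHandle
from pyspark.sql.types import IntegerType, StructField, StructType

class MyProcessor(StatefulProcessor):
    def init(self, handle: StatefulProcessorHandle) -> None:
        schema = StructType([StructField("value", IntegerType(), True)])
        self.counter = handle.getValueState("counter", schema)
        handle.deleteIfExists("obsoleteState")  # purge state from an earlier run, if any

    def handleInputRows(self, key, rows):
        yield from rows  # pass-through, for illustration

    def close(self) -> None:
        pass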


class StatefulProcessor(ABC):
"""
15 changes: 15 additions & 0 deletions python/pyspark/sql/streaming/stateful_processor_api_client.py
@@ -184,6 +184,21 @@ def get_map_state(
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error initializing map state: {response_message[1]}")

def delete_if_exists(self, state_name: str) -> None:
import pyspark.sql.streaming.StateMessage_pb2 as stateMessage

state_call_command = stateMessage.StateCallCommand()
state_call_command.stateName = state_name
call = stateMessage.StatefulProcessorCall(deleteIfExists=state_call_command)
message = stateMessage.StateRequest(statefulProcessorCall=call)

self._send_proto_message(message.SerializeToString())
response_message = self._receive_proto_message()
status = response_message[0]
if status != 0:
# TODO(SPARK-49233): Classify user facing errors.
raise PySparkRuntimeError(f"Error deleting state: {response_message[1]}")

def _send_proto_message(self, message: bytes) -> None:
# Writing zero here to indicate message version. This allows us to evolve the message
# format or even change the message protocol in the future.
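
For context, a rough sketch of what the version-plus-length framing described in this comment might look like; this is an assumption for illustration only (the actual socket handling is elided in this diff):

import struct

def send_framed(sockfile, message: bytes) -> None:
    sockfile.write(struct.pack(">i", 0))             # version marker, currently zero
    sockfile.write(struct.pack(">i", len(message)))  # assumed big-endian length prefix
    sockfile.write(message)
    sockfile.flush()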
@@ -25,6 +25,7 @@
from typing import cast

from pyspark import SparkConf
from pyspark.errors import PySparkRuntimeError
from pyspark.sql.functions import split
from pyspark.sql.types import (
StringType,
@@ -364,15 +365,19 @@ def check_results(batch_df, batch_id):
input_dir.cleanup()


class SimpleStatefulProcessor(StatefulProcessor):
class SimpleStatefulProcessor(StatefulProcessor, unittest.TestCase):
dict = {0: {"0": 1, "1": 2}, 1: {"0": 4, "1": 3}}
batch_id = 0

def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", IntegerType(), True)])
self.num_violations_state = handle.getValueState("numViolations", state_schema)
self.temp_state = handle.getValueState("tempState", state_schema)
handle.deleteIfExists("tempState")

def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
with self.assertRaisesRegex(PySparkRuntimeError, "Error checking value state exists"):
self.temp_state.exists()
new_violations = 0
count = 0
key_str = key[0]
@@ -400,10 +405,12 @@ def close(self) -> None:

# A stateful processor that inherits all behavior of SimpleStatefulProcessor except that it uses
# TTL state with a large timeout.
class SimpleTTLStatefulProcessor(SimpleStatefulProcessor):
class SimpleTTLStatefulProcessor(SimpleStatefulProcessor, unittest.TestCase):
def init(self, handle: StatefulProcessorHandle) -> None:
state_schema = StructType([StructField("value", IntegerType(), True)])
self.num_violations_state = handle.getValueState("numViolations", state_schema, 30000)
self.temp_state = handle.getValueState("tempState", state_schema)
handle.deleteIfExists("tempState")


class TTLStatefulProcessor(StatefulProcessor):
Expand Up @@ -40,6 +40,7 @@ message StatefulProcessorCall {
StateCallCommand getValueState = 2;
StateCallCommand getListState = 3;
StateCallCommand getMapState = 4;
StateCallCommand deleteIfExists = 5;
}
}
