[SPARK-48258][PYTHON][CONNECT] Checkpoint and localCheckpoint in Spark Connect

### What changes were proposed in this pull request?

This PR proposes to add the `DataFrame.checkpoint` and `DataFrame.localCheckpoint` APIs to Spark Connect.

#### Overview

![Screenshot 2024-05-16 at 10 39 25 AM](https://github.com/apache/spark/assets/6477701/c5c4754f-3d5e-4f4a-8f9d-a7218ce49320)

1. The Spark Connect client invokes `checkpoint`/`localCheckpoint`.
    - It connects to the server, which stores a (session ID, UUID) <> checkpointed DataFrame mapping.
2. The server executes the `checkpoint`/`localCheckpoint`.
3. The server returns the UUID for the checkpointed DataFrame.
    - The client side holds that UUID, with the original protobuf plan truncated (replaced) by a `CachedRemoteRelation`.
4. When the client-side DataFrame is garbage-collected, a command is sent to clear the corresponding state on the Spark Connect server.
5. If the checkpointed RDD is no longer referenced (e.g., not even by a temp view), it is cleaned up by `ContextCleaner`, which runs separately and periodically.
6. *When the session is closed, the server attempts to clear all mapped state for that session (because Python does not guarantee that `DataFrame.__del__` is called upon garbage collection).
7. *If the checkpointed RDD is no longer referenced (e.g., not even by a temp view), it is cleaned up by `ContextCleaner`, which runs separately and periodically.

*In 99.999% of cases, the state (`map<(session_id, uuid), checkpointed DataFrame>`) is cleared when the client-side DataFrame is garbage-collected, barring crashes; in practice, Py4J relies on the same garbage-collection hook to clean up its Java objects. For the remaining 0.001%, steps 6 and 7 cover them: both happen when the session is closed and the session holder is released, see also [#41580](#41580).
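
The sketch below illustrates this lifecycle in plain Python. It is only an illustration: the names (`SessionState`, `RemoteDataFrame`, `send_remove_cached_remote_relation`) are hypothetical stand-ins for the `SessionHolder` cache and the `DataFrame.__del__` hook in the diff further down.

```python
import uuid


# Server side (hypothetical stand-in for SessionHolder): per-session cache of
# checkpointed DataFrames, keyed by a generated relation id.
class SessionState:
    def __init__(self):
        self.dataframe_cache = {}  # relation_id -> checkpointed DataFrame

    def cache_dataframe(self, df):
        relation_id = str(uuid.uuid4())
        self.dataframe_cache[relation_id] = df
        return relation_id

    def remove_cached_dataframe(self, relation_id):
        self.dataframe_cache.pop(relation_id, None)


# Client side (hypothetical): the returned DataFrame only carries the relation id;
# when it is garbage-collected it asks the server to drop the cached entry (step 4).
class RemoteDataFrame:
    def __init__(self, client, relation_id):
        self._client = client
        self._relation_id = relation_id

    def __del__(self):
        # Best-effort cleanup; steps 6 and 7 cover the cases where this never runs.
        self._client.send_remove_cached_remote_relation(self._relation_id)
```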

#### Command/RPCs

Reuse `CachedRemoteRelation` (from [#41580](#41580))

```proto
message Command {
  oneof command_type {
    ...
    CheckpointCommand checkpoint_command = 14;
    RemoveCachedRemoteRelationCommand remove_cached_remote_relation_command = 15;
    ...
  }
}

// Command to remove `CachedRemoteRelation`
message RemoveCachedRemoteRelationCommand {
  // (Required) The cached remote relation to be removed.
  CachedRemoteRelation relation = 1;
}

message CheckpointCommand {
  // (Required) The logical plan to checkpoint.
  Relation relation = 1;

  // (Optional) Locally checkpoint using a local temporary
  // directory in Spark Connect server (Spark Driver)
  optional bool local = 2;

  // (Optional) Whether to checkpoint this dataframe immediately.
  optional bool eager = 3;
}

message CheckpointCommandResult {
  // (Required) The logical plan checkpointed.
  CachedRemoteRelation relation = 1;
}
```

```proto
message ExecutePlanResponse {

  ...

  oneof response_type {

    ...

    CheckpointCommandResult checkpoint_command_result = 19;
  }

  ...

  message Checkpoint {
    // (Required) The logical plan checkpointed.
    CachedRemoteRelation relation = ...;
  }
}
```
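
As a rough sketch of how the client wires these messages together (this mirrors the `Checkpoint` and `RemoveRemoteCachedRelation` plan nodes added to `plan.py` further down; `child_relation` is a placeholder for the child plan's `Relation` proto):

```python
import pyspark.sql.connect.proto as proto


def build_checkpoint_command(child_relation: proto.Relation, local: bool, eager: bool) -> proto.Command:
    # Wrap the child plan in a CheckpointCommand; the server answers with a
    # CheckpointCommandResult carrying the CachedRemoteRelation id.
    cmd = proto.Command()
    cmd.checkpoint_command.CopyFrom(
        proto.CheckpointCommand(relation=child_relation, local=local, eager=eager)
    )
    return cmd


def build_remove_command(relation_id: str) -> proto.Command:
    # Sent when the client-side DataFrame is garbage-collected.
    cmd = proto.Command()
    cmd.remove_cached_remote_relation_command.relation.CopyFrom(
        proto.CachedRemoteRelation(relation_id=relation_id)
    )
    return cmd
```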

#### Usage

```bash
./sbin/start-connect-server.sh --conf spark.checkpoint.dir=/path/to/checkpoint
```

```python
spark.range(1).localCheckpoint()
spark.range(1).checkpoint()
```
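
For example, a remote client could then exercise the new API end to end (a sketch, assuming the Connect server above is listening on the default `sc://localhost:15002`):

```python
from pyspark.sql import SparkSession

# Connect to the running Spark Connect server and checkpoint a DataFrame remotely.
spark = SparkSession.builder.remote("sc://localhost:15002").getOrCreate()

df = spark.range(10).filter("id % 2 == 0").localCheckpoint()
df.show()  # the plan is now a CachedRemoteRelation backed by a server-side DataFrame
```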

### Why are the changes needed?

For feature parity with the non-Connect PySpark API.

### Does this PR introduce _any_ user-facing change?

Yes, it adds both the `DataFrame.checkpoint` and `DataFrame.localCheckpoint` APIs to Spark Connect.

### How was this patch tested?

Unit tests, and manual testing as below:

**Code**

```bash
./bin/pyspark --remote "local[*]"
```

```python
>>> df = spark.range(1).localCheckpoint()
>>> df.explain(True)
== Parsed Logical Plan ==
LogicalRDD [id#1L], false

== Analyzed Logical Plan ==
id: bigint
LogicalRDD [id#1L], false

== Optimized Logical Plan ==
LogicalRDD [id#1L], false

== Physical Plan ==
*(1) Scan ExistingRDD[id#1L]

>>> df._plan
<pyspark.sql.connect.plan.CachedRemoteRelation object at 0x147734a50>
>>> del df
```

**Logs**

```
...
{"ts":"2024-05-14T06:18:01.711Z","level":"INFO","msg":"Caching DataFrame with id 7316f315-d20d-446d-b5e7-ac848870e280","context":{"dataframe_id":"7316f315-d20d-446d-b5e7-ac848870e280"},"logger":"SparkConnectAnalyzeHandler"}
...
{"ts":"2024-05-14T06:18:11.718Z","level":"INFO","msg":"Removing DataFrame with id 7316f315-d20d-446d-b5e7-ac848870e280 from the cache","context":{"dataframe_id":"7316f315-d20d-446d-b5e7-ac848870e280"},"logger":"SparkConnectPlanner"}
...
```

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #46570 from HyukjinKwon/SPARK-48258.

Authored-by: Hyukjin Kwon <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
HyukjinKwon committed May 21, 2024
1 parent 0393ab4 commit 7d6bb74
Showing 18 changed files with 656 additions and 295 deletions.
@@ -381,6 +381,9 @@ message ExecutePlanResponse {
// (Optional) Intermediate query progress reports.
ExecutionProgress execution_progress = 18;

// Response for command that checkpoints a DataFrame.
CheckpointCommandResult checkpoint_command_result = 19;

// Support arbitrary result objects.
google.protobuf.Any extension = 999;
}
@@ -1048,6 +1051,11 @@ message FetchErrorDetailsResponse {
}
}

message CheckpointCommandResult {
// (Required) The logical plan checkpointed.
CachedRemoteRelation relation = 1;
}

// Main interface for the SparkConnect service.
service SparkConnectService {

@@ -45,6 +45,8 @@ message Command {
StreamingQueryListenerBusCommand streaming_query_listener_bus_command = 11;
CommonInlineUserDefinedDataSource register_data_source = 12;
CreateResourceProfileCommand create_resource_profile_command = 13;
CheckpointCommand checkpoint_command = 14;
RemoveCachedRemoteRelationCommand remove_cached_remote_relation_command = 15;

// This field is used to mark extensions to the protocol. When plugins generate arbitrary
// Commands they can add them here. During the planning the correct resolution is done.
@@ -484,3 +486,21 @@ message CreateResourceProfileCommandResult {
// (Required) Server-side generated resource profile id.
int32 profile_id = 1;
}

// Command to remove `CachedRemoteRelation`
message RemoveCachedRemoteRelationCommand {
// (Required) The cached remote relation to be removed.
CachedRemoteRelation relation = 1;
}

message CheckpointCommand {
// (Required) The logical plan to checkpoint.
Relation relation = 1;

// (Optional) Locally checkpoint using a local temporary
// directory in Spark Connect server (Spark Driver)
optional bool local = 2;

// (Optional) Whether to checkpoint this dataframe immediately.
optional bool eager = 3;
}
@@ -17,6 +17,8 @@

package org.apache.spark.sql.connect.planner

import java.util.UUID

import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.util.Try
@@ -33,13 +35,13 @@ import org.apache.spark.{Partition, SparkEnv, TaskContext}
import org.apache.spark.annotation.{DeveloperApi, Since}
import org.apache.spark.api.python.{PythonEvalType, SimplePythonFunction}
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.{CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.{CheckpointCommand, CreateResourceProfileCommand, ExecutePlanResponse, SqlCommand, StreamingForeachFunction, StreamingQueryCommand, StreamingQueryCommandResult, StreamingQueryInstanceId, StreamingQueryManagerCommand, StreamingQueryManagerCommandResult, WriteStreamOperationStart, WriteStreamOperationStartResult}
import org.apache.spark.connect.proto.ExecutePlanResponse.SqlCommandResult
import org.apache.spark.connect.proto.Parse.ParseFormat
import org.apache.spark.connect.proto.StreamingQueryManagerCommandResult.StreamingQueryInstance
import org.apache.spark.connect.proto.WriteStreamOperationStart.TriggerCase
import org.apache.spark.internal.{Logging, MDC}
import org.apache.spark.internal.LogKeys.SESSION_ID
import org.apache.spark.internal.LogKeys.{DATAFRAME_ID, SESSION_ID}
import org.apache.spark.ml.{functions => MLFunctions}
import org.apache.spark.resource.{ExecutorResourceRequest, ResourceProfile, TaskResourceProfile, TaskResourceRequest}
import org.apache.spark.sql.{Column, Dataset, Encoders, ForeachWriter, Observation, RelationalGroupedDataset, SparkSession}
@@ -2581,6 +2583,10 @@ class SparkConnectPlanner(
handleCreateResourceProfileCommand(
command.getCreateResourceProfileCommand,
responseObserver)
case proto.Command.CommandTypeCase.CHECKPOINT_COMMAND =>
handleCheckpointCommand(command.getCheckpointCommand, responseObserver)
case proto.Command.CommandTypeCase.REMOVE_CACHED_REMOTE_RELATION_COMMAND =>
handleRemoveCachedRemoteRelationCommand(command.getRemoveCachedRemoteRelationCommand)

case _ => throw new UnsupportedOperationException(s"$command not supported.")
}
@@ -3507,6 +3513,47 @@ class SparkConnectPlanner(
.build())
}

private def handleCheckpointCommand(
checkpointCommand: CheckpointCommand,
responseObserver: StreamObserver[proto.ExecutePlanResponse]): Unit = {
val target = Dataset
.ofRows(session, transformRelation(checkpointCommand.getRelation))
val checkpointed = if (checkpointCommand.hasLocal && checkpointCommand.hasEager) {
target.localCheckpoint(eager = checkpointCommand.getEager)
} else if (checkpointCommand.hasLocal) {
target.localCheckpoint()
} else if (checkpointCommand.hasEager) {
target.checkpoint(eager = checkpointCommand.getEager)
} else {
target.checkpoint()
}

val dfId = UUID.randomUUID().toString
logInfo(log"Caching DataFrame with id ${MDC(DATAFRAME_ID, dfId)}")
sessionHolder.cacheDataFrameById(dfId, checkpointed)

executeHolder.eventsManager.postFinished()
responseObserver.onNext(
proto.ExecutePlanResponse
.newBuilder()
.setSessionId(sessionId)
.setServerSideSessionId(sessionHolder.serverSessionId)
.setCheckpointCommandResult(
proto.CheckpointCommandResult
.newBuilder()
.setRelation(proto.CachedRemoteRelation.newBuilder().setRelationId(dfId).build())
.build())
.build())
}

private def handleRemoveCachedRemoteRelationCommand(
removeCachedRemoteRelationCommand: proto.RemoveCachedRemoteRelationCommand): Unit = {
val dfId = removeCachedRemoteRelationCommand.getRelation.getRelationId
logInfo(log"Removing DataFrame with id ${MDC(DATAFRAME_ID, dfId)} from the cache")
sessionHolder.removeCachedDataFrame(dfId)
executeHolder.eventsManager.postFinished()
}

private val emptyLocalRelation = LocalRelation(
output = AttributeReference("value", StringType, false)() :: Nil,
data = Seq.empty)
@@ -105,8 +105,9 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
val eventManager: SessionEventsManager = SessionEventsManager(this, new SystemClock())

// Mapping from relation ID (passed to client) to runtime dataframe. Used for callbacks like
// foreachBatch() in Streaming. Lazy since most sessions don't need it.
private lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()
// foreachBatch() in Streaming, and DataFrame.checkpoint API. Lazy since most sessions don't
// need it.
private[spark] lazy val dataFrameCache: ConcurrentMap[String, DataFrame] = new ConcurrentHashMap()

// Mapping from id to StreamingQueryListener. Used for methods like removeListener() in
// StreamingQueryManager.
@@ -315,6 +315,12 @@ object SparkConnectService extends Logging {
previoslyObservedSessionId)
}

// For testing
private[spark] def getOrCreateIsolatedSession(
userId: String, sessionId: String): SessionHolder = {
getOrCreateIsolatedSession(userId, sessionId, None)
}

/**
* If there are no executions, return Left with System.currentTimeMillis of last active
* execution. Otherwise return Right with list of ExecuteInfo of all executions.
@@ -24,7 +24,7 @@
class SparkFrameMethodsParityTests(
SparkFrameMethodsTestsMixin, TestUtils, PandasOnSparkTestUtils, ReusedConnectTestCase
):
@unittest.skip("Test depends on checkpoint which is not supported from Spark Connect.")
@unittest.skip("Test depends on SparkContext which is not supported from Spark Connect.")
def test_checkpoint(self):
super().test_checkpoint()

@@ -34,10 +34,6 @@ def test_checkpoint(self):
def test_coalesce(self):
super().test_coalesce()

@unittest.skip("Test depends on localCheckpoint which is not supported from Spark Connect.")
def test_local_checkpoint(self):
super().test_local_checkpoint()

@unittest.skip(
"Test depends on RDD, and cannot use SQL expression due to Catalyst optimization"
)
12 changes: 11 additions & 1 deletion python/pyspark/sql/connect/client/core.py
@@ -66,7 +66,11 @@
from pyspark.sql.connect.profiler import ConnectProfilerCollector
from pyspark.sql.connect.client.reattach import ExecutePlanResponseReattachableIterator
from pyspark.sql.connect.client.retries import RetryPolicy, Retrying, DefaultPolicy
from pyspark.sql.connect.conversion import storage_level_to_proto, proto_to_storage_level
from pyspark.sql.connect.conversion import (
storage_level_to_proto,
proto_to_storage_level,
proto_to_remote_cached_dataframe,
)
import pyspark.sql.connect.proto as pb2
import pyspark.sql.connect.proto.base_pb2_grpc as grpc_lib
import pyspark.sql.connect.types as types
@@ -1400,6 +1404,12 @@ def handle_response(
if b.HasField("create_resource_profile_command_result"):
profile_id = b.create_resource_profile_command_result.profile_id
yield {"create_resource_profile_command_result": profile_id}
if b.HasField("checkpoint_command_result"):
yield {
"checkpoint_command_result": proto_to_remote_cached_dataframe(
b.checkpoint_command_result.relation
)
}

try:
if self._use_reattachable_execute:
23 changes: 17 additions & 6 deletions python/pyspark/sql/connect/conversion.py
@@ -48,12 +48,10 @@
import pyspark.sql.connect.proto as pb2
from pyspark.sql.pandas.types import to_arrow_schema, _dedup_names, _deduplicate_field_names

from typing import (
Any,
Callable,
Sequence,
List,
)
from typing import Any, Callable, Sequence, List, TYPE_CHECKING

if TYPE_CHECKING:
from pyspark.sql.connect.dataframe import DataFrame


class LocalDataToArrowConversion:
@@ -570,3 +568,16 @@ def proto_to_storage_level(storage_level: pb2.StorageLevel) -> StorageLevel:
deserialized=storage_level.deserialized,
replication=storage_level.replication,
)


def proto_to_remote_cached_dataframe(relation: pb2.CachedRemoteRelation) -> "DataFrame":
assert relation is not None and isinstance(relation, pb2.CachedRemoteRelation)

from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.session import SparkSession
import pyspark.sql.connect.plan as plan

return DataFrame(
plan=plan.CachedRemoteRelation(relation.relation_id),
session=SparkSession.active(),
)
69 changes: 54 additions & 15 deletions python/pyspark/sql/connect/dataframe.py
@@ -16,7 +16,7 @@
#

# mypy: disable-error-code="override"

from pyspark.sql.connect.proto import base_pb2 as spark_dot_connect_dot_base__pb2
from pyspark.errors.exceptions.base import (
SessionNotSameException,
PySparkIndexError,
@@ -138,6 +138,41 @@ def __init__(
# by __repr__ and _repr_html_ while eager evaluation opens.
self._support_repr_html = False
self._cached_schema: Optional[StructType] = None
self._cached_remote_relation_id: Optional[str] = None

def __del__(self) -> None:
# If session is already closed, all cached DataFrame should be released.
if not self._session.client.is_closed and self._cached_remote_relation_id is not None:
try:
command = plan.RemoveRemoteCachedRelation(
plan.CachedRemoteRelation(relationId=self._cached_remote_relation_id)
).command(session=self._session.client)
req = self._session.client._execute_plan_request_with_metadata()
if self._session.client._user_id:
req.user_context.user_id = self._session.client._user_id
req.plan.command.CopyFrom(command)

for attempt in self._session.client._retrying():
with attempt:
# !!HACK ALERT!!
# unary_stream does not work on Python's exit for an unknown reasons
# Therefore, here we open unary_unary channel instead.
# See also :class:`SparkConnectServiceStub`.
request_serializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanRequest.SerializeToString
)
response_deserializer = (
spark_dot_connect_dot_base__pb2.ExecutePlanResponse.FromString
)
channel = self._session.client._channel.unary_unary(
"/spark.connect.SparkConnectService/ExecutePlan",
request_serializer=request_serializer,
response_deserializer=response_deserializer,
)
metadata = self._session.client._builder.metadata()
channel(req, metadata=metadata) # type: ignore[arg-type]
except Exception as e:
warnings.warn(f"RemoveRemoteCachedRelation failed with exception: {e}.")

def __reduce__(self) -> Tuple:
"""
@@ -2096,19 +2131,25 @@ def writeTo(self, table: str) -> "DataFrameWriterV2":
def offset(self, n: int) -> ParentDataFrame:
return DataFrame(plan.Offset(child=self._plan, offset=n), session=self._session)

if not is_remote_only():
def checkpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
_, properties = self._session.client.execute_command(cmd.command(self._session.client))
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
return checkpointed

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
cmd = plan.Checkpoint(child=self._plan, local=True, eager=eager)
_, properties = self._session.client.execute_command(cmd.command(self._session.client))
assert "checkpoint_command_result" in properties
checkpointed = properties["checkpoint_command_result"]
assert isinstance(checkpointed._plan, plan.CachedRemoteRelation)
checkpointed._cached_remote_relation_id = checkpointed._plan._relationId
return checkpointed

def checkpoint(self, eager: bool = True) -> "DataFrame":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "checkpoint()"},
)

def localCheckpoint(self, eager: bool = True) -> "DataFrame":
raise PySparkNotImplementedError(
error_class="NOT_IMPLEMENTED",
message_parameters={"feature": "localCheckpoint()"},
)
if not is_remote_only():

def toJSON(self, use_unicode: bool = True) -> "RDD[str]":
raise PySparkNotImplementedError(
@@ -2203,8 +2244,6 @@ def _test() -> None:
if not is_remote_only():
del pyspark.sql.dataframe.DataFrame.toJSON.__doc__
del pyspark.sql.dataframe.DataFrame.rdd.__doc__
del pyspark.sql.dataframe.DataFrame.checkpoint.__doc__
del pyspark.sql.dataframe.DataFrame.localCheckpoint.__doc__

globs["spark"] = (
PySparkSession.builder.appName("sql.connect.dataframe tests")
32 changes: 31 additions & 1 deletion python/pyspark/sql/connect/plan.py
@@ -1785,9 +1785,39 @@ def command(self, session: "SparkConnectClient") -> proto.Command:
return cmd


# Catalog API (internal-only)
class RemoveRemoteCachedRelation(LogicalPlan):
def __init__(self, relation: CachedRemoteRelation) -> None:
super().__init__(None)
self._relation = relation

def command(self, session: "SparkConnectClient") -> proto.Command:
plan = self._create_proto_relation()
plan.cached_remote_relation.relation_id = self._relation._relationId
cmd = proto.Command()
cmd.remove_cached_remote_relation_command.relation.CopyFrom(plan.cached_remote_relation)
return cmd


class Checkpoint(LogicalPlan):
def __init__(self, child: Optional["LogicalPlan"], local: bool, eager: bool) -> None:
super().__init__(child)
self._local = local
self._eager = eager

def command(self, session: "SparkConnectClient") -> proto.Command:
cmd = proto.Command()
assert self._child is not None
cmd.checkpoint_command.CopyFrom(
proto.CheckpointCommand(
relation=self._child.plan(session),
local=self._local,
eager=self._eager,
)
)
return cmd


# Catalog API (internal-only)
class CurrentDatabase(LogicalPlan):
def __init__(self) -> None:
super().__init__(None)
214 changes: 108 additions & 106 deletions python/pyspark/sql/connect/proto/base_pb2.py

Large diffs are not rendered by default.
