
Commit

Impl
changgyoopark-db committed Jan 24, 2025
1 parent c662441 commit ca61632
Showing 5 changed files with 109 additions and 123 deletions.
@@ -27,7 +27,7 @@ import io.grpc.stub.StreamObserver
import org.apache.spark.SparkEnv
import org.apache.spark.connect.proto
import org.apache.spark.connect.proto.ExecutePlanResponse
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.connect.common.DataTypeProtoConverter
import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto
@@ -68,12 +68,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
} else {
DoNotCleanup
}
val rel = request.getPlan.getRoot
val dataframe =
Dataset.ofRows(
sessionHolder.session,
planner.transformRelation(request.getPlan.getRoot, cachePlan = true),
tracker,
shuffleCleanupMode)
sessionHolder.createDataFrame(rel, planner, Some((tracker, shuffleCleanupMode)))
responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema))
processAsArrowBatches(dataframe, responseObserver, executeHolder)
responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe))
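
The hunk above drops the direct Dataset.ofRows call (which asked the planner to cache the resolved plan via cachePlan = true) and routes DataFrame construction through the session holder instead. A condensed sketch of the new shape, assuming the handler's request, planner, tracker, and shuffleCleanupMode are in scope as in the surrounding method:

    // Plan-cache lookup and DataFrame construction now live in SessionHolder.
    val rel = request.getPlan.getRoot
    val dataframe =
      sessionHolder.createDataFrame(rel, planner, Some((tracker, shuffleCleanupMode)))
    // The schema response and Arrow batches are then emitted from `dataframe`,
    // as in the unchanged responseObserver.onNext(...) calls above.
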
@@ -117,32 +117,16 @@ class SparkConnectPlanner(
private lazy val pythonExec =
sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

/**
* The root of the query plan is a relation and we apply the transformations to it. The resolved
* logical plan will not get cached. If the result needs to be cached, use
* `transformRelation(rel, cachePlan = true)` instead.
* @param rel
* The relation to transform.
* @return
* The resolved logical plan.
*/
@DeveloperApi
def transformRelation(rel: proto.Relation): LogicalPlan =
transformRelation(rel, cachePlan = false)

/**
* The root of the query plan is a relation and we apply the transformations to it.
* @param rel
* The relation to transform.
* @param cachePlan
* Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
* upon by further dataset transformation. The default is false.
* @return
* The resolved logical plan.
*/
@DeveloperApi
def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
sessionHolder.usePlanCache(rel, cachePlan) { rel =>
def transformRelation(rel: proto.Relation): LogicalPlan = {
val (logicalPlan, _) = sessionHolder.usePlanCache(rel) { rel =>
val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
@@ -238,6 +222,7 @@ class SparkConnectPlanner(
}
plan
}
logicalPlan
}

private def transformRelationPlugin(extension: ProtoAny): LogicalPlan = {
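
After this change transformRelation has a single public form: the cachePlan overload is gone, and the method always wraps its relation-to-plan translation in SessionHolder.usePlanCache, ignoring the returned cache-hit flag. A minimal caller-side sketch, assuming a SparkConnectPlanner instance `planner` and a proto.Relation `rel` are in scope (hypothetical names for illustration):

    // The one-argument overload is now the only entry point; callers no longer
    // decide whether the resolved plan should be cached.
    val plan: LogicalPlan = planner.transformRelation(rel)
    // Internally this corresponds to:
    //   val (logicalPlan, _) = sessionHolder.usePlanCache(rel) { r => /* translate r */ }
    //   logicalPlan
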
@@ -32,16 +32,19 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.connect.proto
import org.apache.spark.internal.{Logging, LogKeys, MDC}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.QueryPlanningTracker
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.ml.MLCache
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.ExecuteKey
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
import org.apache.spark.sql.execution.ShuffleCleanupMode
import org.apache.spark.sql.streaming.StreamingQueryListener
import org.apache.spark.util.{SystemClock, Utils}

@@ -441,46 +444,58 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
* `spark.connect.session.planCache.enabled` is true.
* @param rel
* The relation to transform.
* @param cachePlan
* Whether to cache the result logical plan.
* @param transform
* Function to transform the relation into a logical plan.
* @return
* The logical plan.
* The logical plan and a flag indicating that the plan cache was hit.
*/
private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
transform: proto.Relation => LogicalPlan): LogicalPlan = {
val planCacheEnabled = Option(session)
.forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
// We only cache plans that have a plan ID.
val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId

def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
planCache match {
case Some(cache) if planCacheEnabled && hasPlanId =>
Option(cache.getIfPresent(rel)) match {
case Some(plan) =>
logDebug(s"Using cached plan for relation '$rel': $plan")
Some(plan)
case None => None
}
case _ => None
}
def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
planCache match {
case Some(cache) if planCacheEnabled && hasPlanId =>
cache.put(rel, plan)
case _ =>
private[connect] def usePlanCache(rel: proto.Relation)(
transform: proto.Relation => LogicalPlan): (LogicalPlan, Boolean) =
planCache match {
case Some(cache) if planCacheEnabled(rel) =>
Option(cache.getIfPresent(rel)) match {
case Some(plan) =>
logDebug(s"Using cached plan for relation '$rel': $plan")
(plan, true)
case None => (transform(rel), false)
}
case _ => (transform(rel), false)
}

/**
* Create a data frame from the supplied relation, and update the plan cache.
*
* @param rel
* A proto.Relation to create a data frame.
* @return
* The created data frame.
*/
private[connect] def createDataFrame(
rel: proto.Relation,
planner: SparkConnectPlanner,
options: Option[(QueryPlanningTracker, ShuffleCleanupMode)] = None): DataFrame = {
val (plan, cacheHit) = usePlanCache(rel)(r => planner.transformRelation(r))
val df = options match {
case Some((tracker, shuffleCleanupMode)) =>
Dataset.ofRows(session, plan, tracker, shuffleCleanupMode)
case _ => Dataset.ofRows(session, plan)
}
if (!cacheHit && planCache.isDefined && planCacheEnabled(rel)) {
if (df.queryExecution.isLazyAnalysis) {
planCache.get.get(rel, { () => df.queryExecution.logical })
} else {
planCache.get.get(rel, { () => df.queryExecution.analyzed })
}
}
df
}

getPlanCache(rel)
.getOrElse({
val plan = transform(rel)
if (cachePlan) {
putPlanCache(rel, plan)
}
plan
})
// Return true if the plan cache is enabled for the session and the relation.
private def planCacheEnabled(rel: proto.Relation): Boolean = {
// We only cache plans that have a plan ID.
rel.hasCommon && rel.getCommon.hasPlanId &&
Option(session)
.forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
}

// For testing. Expose the plan cache for testing purposes.
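
SessionHolder now owns both halves of the caching story: usePlanCache reports whether the lookup hit, and createDataFrame uses that flag to populate the cache on a miss (with the analyzed plan, or the unanalyzed logical plan for lazy-analysis queries). An illustrative sketch from inside the connect package, assuming rel, planner, tracker, and shuffleCleanupMode are in scope; this is usage illustration, not code from the commit:

    // Raw lookup: returns the plan plus whether it came from the cache.
    val (plan, cacheHit) =
      sessionHolder.usePlanCache(rel)(r => planner.transformRelation(r))

    // Typical path: build a DataFrame and let SessionHolder update the cache on a miss.
    val df = sessionHolder.createDataFrame(rel, planner)
    val tracked =
      sessionHolder.createDataFrame(rel, planner, Some((tracker, shuffleCleanupMode)))
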
@@ -23,7 +23,6 @@ import io.grpc.stub.StreamObserver

import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter}
import org.apache.spark.sql.connect.planner.SparkConnectPlanner
import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode}
@@ -59,23 +58,21 @@ private[connect] class SparkConnectAnalyzeHandler(
val session = sessionHolder.session
val builder = proto.AnalyzePlanResponse.newBuilder()

def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true)

request.getAnalyzeCase match {
case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA =>
val schema = Dataset
.ofRows(session, transformRelation(request.getSchema.getPlan.getRoot))
.schema
val schema =
sessionHolder.createDataFrame(request.getSchema.getPlan.getRoot, planner).schema
builder.setSchema(
proto.AnalyzePlanResponse.Schema
.newBuilder()
.setSchema(DataTypeProtoConverter.toConnectProtoType(schema))
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN =>
val queryExecution = Dataset
.ofRows(session, transformRelation(request.getExplain.getPlan.getRoot))
.queryExecution
val queryExecution =
sessionHolder
.createDataFrame(request.getExplain.getPlan.getRoot, planner)
.queryExecution
val explainString = request.getExplain.getExplainMode match {
case proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE =>
queryExecution.explainString(SimpleMode)
@@ -96,9 +93,8 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING =>
val schema = Dataset
.ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot))
.schema
val schema =
sessionHolder.createDataFrame(request.getTreeString.getPlan.getRoot, planner).schema
val treeString = if (request.getTreeString.hasLevel) {
schema.treeString(request.getTreeString.getLevel)
} else {
@@ -111,29 +107,28 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL =>
val isLocal = Dataset
.ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot))
.isLocal
val isLocal =
sessionHolder.createDataFrame(request.getIsLocal.getPlan.getRoot, planner).isLocal
builder.setIsLocal(
proto.AnalyzePlanResponse.IsLocal
.newBuilder()
.setIsLocal(isLocal)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING =>
val isStreaming = Dataset
.ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot))
.isStreaming
val isStreaming =
sessionHolder
.createDataFrame(request.getIsStreaming.getPlan.getRoot, planner)
.isStreaming
builder.setIsStreaming(
proto.AnalyzePlanResponse.IsStreaming
.newBuilder()
.setIsStreaming(isStreaming)
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES =>
val inputFiles = Dataset
.ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot))
.inputFiles
val inputFiles =
sessionHolder.createDataFrame(request.getInputFiles.getPlan.getRoot, planner).inputFiles
builder.setInputFiles(
proto.AnalyzePlanResponse.InputFiles
.newBuilder()
@@ -156,29 +151,27 @@
.build())

case proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS =>
val target = Dataset.ofRows(
session,
transformRelation(request.getSameSemantics.getTargetPlan.getRoot))
val other = Dataset.ofRows(
session,
transformRelation(request.getSameSemantics.getOtherPlan.getRoot))
val target =
sessionHolder.createDataFrame(request.getSameSemantics.getTargetPlan.getRoot, planner)
val other =
sessionHolder.createDataFrame(request.getSameSemantics.getOtherPlan.getRoot, planner)
builder.setSameSemantics(
proto.AnalyzePlanResponse.SameSemantics
.newBuilder()
.setResult(target.sameSemantics(other)))

case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH =>
val semanticHash = Dataset
.ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot))
val semanticHash = sessionHolder
.createDataFrame(request.getSemanticHash.getPlan.getRoot, planner)
.semanticHash()
builder.setSemanticHash(
proto.AnalyzePlanResponse.SemanticHash
.newBuilder()
.setResult(semanticHash))

case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST =>
val target = Dataset
.ofRows(session, transformRelation(request.getPersist.getRelation))
val target = sessionHolder
.createDataFrame(request.getPersist.getRelation, planner)
if (request.getPersist.hasStorageLevel) {
target.persist(
StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel))
@@ -188,8 +181,8 @@
builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build())

case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST =>
val target = Dataset
.ofRows(session, transformRelation(request.getUnpersist.getRelation))
val target = sessionHolder
.createDataFrame(request.getUnpersist.getRelation, planner)
if (request.getUnpersist.hasBlocking) {
target.unpersist(request.getUnpersist.getBlocking)
} else {
@@ -198,8 +191,8 @@
builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build())

case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL =>
val target = Dataset
.ofRows(session, transformRelation(request.getGetStorageLevel.getRelation))
val target = sessionHolder
.createDataFrame(request.getGetStorageLevel.getRelation, planner)
val storageLevel = target.storageLevel
builder.setGetStorageLevel(
proto.AnalyzePlanResponse.GetStorageLevel
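
Every analyze case in this handler now follows the same pattern: build the DataFrame through sessionHolder.createDataFrame, so the plan cache is consulted and updated in one place, then read the property the request asked for. A condensed sketch using a hypothetical local helper (not part of the commit) to show the shape:

    // Hypothetical helper, for illustration only.
    def df(root: proto.Relation): DataFrame = sessionHolder.createDataFrame(root, planner)

    val schema      = df(request.getSchema.getPlan.getRoot).schema
    val isStreaming = df(request.getIsStreaming.getPlan.getRoot).isStreaming
    val inputFiles  = df(request.getInputFiles.getPlan.getRoot).inputFiles
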
