diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala index 5e3499573e9d9..0ebb9449e6ce7 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/execution/SparkConnectPlanExecution.scala @@ -27,7 +27,7 @@ import io.grpc.stub.StreamObserver import org.apache.spark.SparkEnv import org.apache.spark.connect.proto import org.apache.spark.connect.proto.ExecutePlanResponse -import org.apache.spark.sql.{DataFrame, Dataset} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.connect.common.DataTypeProtoConverter import org.apache.spark.sql.connect.common.LiteralValueProtoConverter.toLiteralProto @@ -68,12 +68,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder) } else { DoNotCleanup } + val rel = request.getPlan.getRoot val dataframe = - Dataset.ofRows( - sessionHolder.session, - planner.transformRelation(request.getPlan.getRoot, cachePlan = true), - tracker, - shuffleCleanupMode) + sessionHolder.createDataFrame(rel, planner, Some((tracker, shuffleCleanupMode))) responseObserver.onNext(createSchemaResponse(request.getSessionId, dataframe.schema)) processAsArrowBatches(dataframe, responseObserver, executeHolder) responseObserver.onNext(MetricGenerator.createMetricsResponse(sessionHolder, dataframe)) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 56824bbb4a417..58f4c7772f402 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ 
b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -117,32 +117,16 @@ class SparkConnectPlanner( private lazy val pythonExec = sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3")) - /** - * The root of the query plan is a relation and we apply the transformations to it. The resolved - * logical plan will not get cached. If the result needs to be cached, use - * `transformRelation(rel, cachePlan = true)` instead. - * @param rel - * The relation to transform. - * @return - * The resolved logical plan. - */ - @DeveloperApi - def transformRelation(rel: proto.Relation): LogicalPlan = - transformRelation(rel, cachePlan = false) - /** * The root of the query plan is a relation and we apply the transformations to it. * @param rel * The relation to transform. - * @param cachePlan - * Set to true for a performance optimization, if the plan is likely to be reused, e.g. built - * upon by further dataset transformation. The default is false. * @return * The resolved logical plan. 
*/ @DeveloperApi - def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = { - sessionHolder.usePlanCache(rel, cachePlan) { rel => + def transformRelation(rel: proto.Relation): LogicalPlan = { + sessionHolder.usePlanCache(rel) { rel => val plan = rel.getRelTypeCase match { // DataFrame API case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString) diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala index 5b56b7079a897..085a07c5a7e1b 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala @@ -32,16 +32,19 @@ import org.apache.spark.{SparkEnv, SparkException, SparkSQLException} import org.apache.spark.api.python.PythonFunction.PythonAccumulator import org.apache.spark.connect.proto import org.apache.spark.internal.{Logging, LogKeys, MDC} -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.QueryPlanningTracker import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.connect.common.InvalidPlanInput import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.ml.MLCache import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener +import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper import org.apache.spark.sql.connect.service.ExecuteKey import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC} +import org.apache.spark.sql.execution.ShuffleCleanupMode import org.apache.spark.sql.streaming.StreamingQueryListener import 
org.apache.spark.util.{SystemClock, Utils}
@@ -441,46 +444,69 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    * `spark.connect.session.planCache.enabled` is true.
    * @param rel
    *   The relation to transform.
-   * @param cachePlan
-   *   Whether to cache the result logical plan.
    * @param transform
    *   Function to transform the relation into a logical plan.
    * @return
    *   The logical plan.
    */
-  private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
-      transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
-    // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
-
-    def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          Option(cache.getIfPresent(rel)) match {
-            case Some(plan) =>
-              logDebug(s"Using cached plan for relation '$rel': $plan")
-              Some(plan)
-            case None => None
-          }
-        case _ => None
-      }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
-      planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
-      }
+  private[connect] def usePlanCache(rel: proto.Relation)(
+      transform: proto.Relation => LogicalPlan): LogicalPlan =
+    planCache match {
+      case Some(cache) if planCacheEnabled(rel) =>
+        Option(cache.getIfPresent(rel)) match {
+          case Some(plan) =>
+            logDebug(s"Using cached plan for relation '$rel': $plan")
+            plan
+          case None => transform(rel)
+        }
+      case _ => transform(rel)
+    }
+
+  /**
+   * Create a data frame from the supplied relation, and update the plan cache.
+   *
+   * @param rel
+   *   The relation to create the data frame from.
+   * @param planner
+   *   The planner used to transform `rel` into a logical plan.
+   * @param options
+   *   An optional pair of query planning tracker and shuffle cleanup mode. When defined, both
+   *   are passed on to `Dataset.ofRows`; otherwise its two-argument form is used.
+   * @return
+   *   The created data frame.
+   */
+  private[connect] def createDataFrame(
+      rel: proto.Relation,
+      planner: SparkConnectPlanner,
+      options: Option[(QueryPlanningTracker, ShuffleCleanupMode)] = None): DataFrame = {
+    val plan = planner.transformRelation(rel)
+    val df = options match {
+      case Some((tracker, shuffleCleanupMode)) =>
+        Dataset.ofRows(session, plan, tracker, shuffleCleanupMode)
+      case None => Dataset.ofRows(session, plan)
+    }
+    planCache match {
+      case Some(cache) if planCacheEnabled(rel) =>
+        // NOTE(review): assuming Guava `Cache.get(key, loader)` semantics, the loader only runs
+        // when `rel` has no entry yet, so an already cached plan is left untouched. A lazily
+        // analyzed data frame contributes its logical plan; otherwise the analyzed plan is used.
+        cache.get(
+          rel,
+          { () =>
+            val qe = df.queryExecution
+            if (qe.isLazyAnalysis) qe.logical else qe.analyzed
+          })
+      case _ =>
+    }
+    df
+  }
 
-    getPlanCache(rel)
-      .getOrElse({
-        val plan = transform(rel)
-        if (cachePlan) {
-          putPlanCache(rel, plan)
-        }
-        plan
-      })
+  // Return true if the plan cache is enabled for the session and the relation.
+  private def planCacheEnabled(rel: proto.Relation): Boolean = {
+    // We only cache plans that have a plan ID.
+    rel.hasCommon && rel.getCommon.hasPlanId &&
+    Option(session)
+      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
   }
 
   // For testing. Expose the plan cache for testing purposes.
diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala index 8ca021c5be39e..cabd3c7637706 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SparkConnectAnalyzeHandler.scala @@ -23,7 +23,6 @@ import io.grpc.stub.StreamObserver import org.apache.spark.connect.proto import org.apache.spark.internal.Logging -import org.apache.spark.sql.Dataset import org.apache.spark.sql.connect.common.{DataTypeProtoConverter, InvalidPlanInput, StorageLevelProtoConverter} import org.apache.spark.sql.connect.planner.SparkConnectPlanner import org.apache.spark.sql.execution.{CodegenMode, CostMode, ExtendedMode, FormattedMode, SimpleMode} @@ -59,13 +58,10 @@ private[connect] class SparkConnectAnalyzeHandler( val session = sessionHolder.session val builder = proto.AnalyzePlanResponse.newBuilder() - def transformRelation(rel: proto.Relation) = planner.transformRelation(rel, cachePlan = true) - request.getAnalyzeCase match { case proto.AnalyzePlanRequest.AnalyzeCase.SCHEMA => - val schema = Dataset - .ofRows(session, transformRelation(request.getSchema.getPlan.getRoot)) - .schema + val schema = + sessionHolder.createDataFrame(request.getSchema.getPlan.getRoot, planner).schema builder.setSchema( proto.AnalyzePlanResponse.Schema .newBuilder() @@ -73,9 +69,10 @@ private[connect] class SparkConnectAnalyzeHandler( .build()) case proto.AnalyzePlanRequest.AnalyzeCase.EXPLAIN => - val queryExecution = Dataset - .ofRows(session, transformRelation(request.getExplain.getPlan.getRoot)) - .queryExecution + val queryExecution = + sessionHolder + .createDataFrame(request.getExplain.getPlan.getRoot, planner) + .queryExecution val explainString = request.getExplain.getExplainMode match { case 
proto.AnalyzePlanRequest.Explain.ExplainMode.EXPLAIN_MODE_SIMPLE => queryExecution.explainString(SimpleMode) @@ -96,9 +93,8 @@ private[connect] class SparkConnectAnalyzeHandler( .build()) case proto.AnalyzePlanRequest.AnalyzeCase.TREE_STRING => - val schema = Dataset - .ofRows(session, transformRelation(request.getTreeString.getPlan.getRoot)) - .schema + val schema = + sessionHolder.createDataFrame(request.getTreeString.getPlan.getRoot, planner).schema val treeString = if (request.getTreeString.hasLevel) { schema.treeString(request.getTreeString.getLevel) } else { @@ -111,9 +107,8 @@ private[connect] class SparkConnectAnalyzeHandler( .build()) case proto.AnalyzePlanRequest.AnalyzeCase.IS_LOCAL => - val isLocal = Dataset - .ofRows(session, transformRelation(request.getIsLocal.getPlan.getRoot)) - .isLocal + val isLocal = + sessionHolder.createDataFrame(request.getIsLocal.getPlan.getRoot, planner).isLocal builder.setIsLocal( proto.AnalyzePlanResponse.IsLocal .newBuilder() @@ -121,9 +116,10 @@ private[connect] class SparkConnectAnalyzeHandler( .build()) case proto.AnalyzePlanRequest.AnalyzeCase.IS_STREAMING => - val isStreaming = Dataset - .ofRows(session, transformRelation(request.getIsStreaming.getPlan.getRoot)) - .isStreaming + val isStreaming = + sessionHolder + .createDataFrame(request.getIsStreaming.getPlan.getRoot, planner) + .isStreaming builder.setIsStreaming( proto.AnalyzePlanResponse.IsStreaming .newBuilder() @@ -131,9 +127,8 @@ private[connect] class SparkConnectAnalyzeHandler( .build()) case proto.AnalyzePlanRequest.AnalyzeCase.INPUT_FILES => - val inputFiles = Dataset - .ofRows(session, transformRelation(request.getInputFiles.getPlan.getRoot)) - .inputFiles + val inputFiles = + sessionHolder.createDataFrame(request.getInputFiles.getPlan.getRoot, planner).inputFiles builder.setInputFiles( proto.AnalyzePlanResponse.InputFiles .newBuilder() @@ -156,20 +151,18 @@ private[connect] class SparkConnectAnalyzeHandler( .build()) case 
proto.AnalyzePlanRequest.AnalyzeCase.SAME_SEMANTICS => - val target = Dataset.ofRows( - session, - transformRelation(request.getSameSemantics.getTargetPlan.getRoot)) - val other = Dataset.ofRows( - session, - transformRelation(request.getSameSemantics.getOtherPlan.getRoot)) + val target = + sessionHolder.createDataFrame(request.getSameSemantics.getTargetPlan.getRoot, planner) + val other = + sessionHolder.createDataFrame(request.getSameSemantics.getOtherPlan.getRoot, planner) builder.setSameSemantics( proto.AnalyzePlanResponse.SameSemantics .newBuilder() .setResult(target.sameSemantics(other))) case proto.AnalyzePlanRequest.AnalyzeCase.SEMANTIC_HASH => - val semanticHash = Dataset - .ofRows(session, transformRelation(request.getSemanticHash.getPlan.getRoot)) + val semanticHash = sessionHolder + .createDataFrame(request.getSemanticHash.getPlan.getRoot, planner) .semanticHash() builder.setSemanticHash( proto.AnalyzePlanResponse.SemanticHash @@ -177,8 +170,8 @@ private[connect] class SparkConnectAnalyzeHandler( .setResult(semanticHash)) case proto.AnalyzePlanRequest.AnalyzeCase.PERSIST => - val target = Dataset - .ofRows(session, transformRelation(request.getPersist.getRelation)) + val target = sessionHolder + .createDataFrame(request.getPersist.getRelation, planner) if (request.getPersist.hasStorageLevel) { target.persist( StorageLevelProtoConverter.toStorageLevel(request.getPersist.getStorageLevel)) @@ -188,8 +181,8 @@ private[connect] class SparkConnectAnalyzeHandler( builder.setPersist(proto.AnalyzePlanResponse.Persist.newBuilder().build()) case proto.AnalyzePlanRequest.AnalyzeCase.UNPERSIST => - val target = Dataset - .ofRows(session, transformRelation(request.getUnpersist.getRelation)) + val target = sessionHolder + .createDataFrame(request.getUnpersist.getRelation, planner) if (request.getUnpersist.hasBlocking) { target.unpersist(request.getUnpersist.getBlocking) } else { @@ -198,8 +191,8 @@ private[connect] class SparkConnectAnalyzeHandler( 
builder.setUnpersist(proto.AnalyzePlanResponse.Unpersist.newBuilder().build()) case proto.AnalyzePlanRequest.AnalyzeCase.GET_STORAGE_LEVEL => - val target = Dataset - .ofRows(session, transformRelation(request.getGetStorageLevel.getRelation)) + val target = sessionHolder + .createDataFrame(request.getGetStorageLevel.getRelation, planner) val storageLevel = target.storageLevel builder.setGetStorageLevel( proto.AnalyzePlanResponse.GetStorageLevel diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala index 21f84291a2f07..ef2f98e6db6cf 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala @@ -309,11 +309,14 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { private def assertPlanCache( sessionHolder: SessionHolder, - optionExpectedCachedRelations: Option[Set[proto.Relation]]) = { + optionExpectedCachedRelations: Option[Set[proto.Relation]], + expectAnalyzed: Boolean) = { optionExpectedCachedRelations match { case Some(expectedCachedRelations) => val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala assert(cachedRelations.size == expectedCachedRelations.size) + val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala + cachedLogicalPlans.foreach(plan => assert(plan.analyzed == expectAnalyzed)) expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation))) case None => assert(sessionHolder.getPlanCache.isEmpty) } @@ -345,29 +348,29 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { .setCommon(proto.RelationCommon.newBuilder().setPlanId(Random.nextLong()).build()) .build() - // If cachePlan is false, the cache is 
still empty. - planner.transformRelation(random1, cachePlan = false) - assertPlanCache(sessionHolder, Some(Set())) + // Transform the relation without analysis, the cache is still empty. + val random1Plan = planner.transformRelation(random1) + assertPlanCache(sessionHolder, Some(Set()), false) - // Put a random entry in cache. - planner.transformRelation(random1, cachePlan = true) - assertPlanCache(sessionHolder, Some(Set(random1))) + // Put a random entry in cache after analysis. + sessionHolder.createDataFrame(random1, planner) + assertPlanCache(sessionHolder, Some(Set(random1)), true) // Put another random entry in cache. - planner.transformRelation(random2, cachePlan = true) - assertPlanCache(sessionHolder, Some(Set(random1, random2))) + sessionHolder.createDataFrame(random2, planner) + assertPlanCache(sessionHolder, Some(Set(random1, random2)), true) // Analyze query1. We only cache the root relation, and the random1 is evicted. - planner.transformRelation(query1, cachePlan = true) - assertPlanCache(sessionHolder, Some(Set(random2, query1))) + sessionHolder.createDataFrame(query1, planner) + assertPlanCache(sessionHolder, Some(Set(random2, query1)), true) // Put another random entry in cache. - planner.transformRelation(random3, cachePlan = true) - assertPlanCache(sessionHolder, Some(Set(query1, random3))) + sessionHolder.createDataFrame(random3, planner) + assertPlanCache(sessionHolder, Some(Set(query1, random3)), true) // Analyze query2. As query1 is accessed during the process, it should be in the cache. - planner.transformRelation(query2, cachePlan = true) - assertPlanCache(sessionHolder, Some(Set(query1, query2))) + sessionHolder.createDataFrame(query2, planner) + assertPlanCache(sessionHolder, Some(Set(query1, query2)), true) } finally { // Set back to default value. 
SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5) @@ -383,13 +386,9 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { val query = buildRelation("select 1") - // If cachePlan is false, the cache is still None. - planner.transformRelation(query, cachePlan = false) - assertPlanCache(sessionHolder, None) - - // Even if we specify "cachePlan = true", the cache is still None. - planner.transformRelation(query, cachePlan = true) - assertPlanCache(sessionHolder, None) + // The cache must be empty. + sessionHolder.createDataFrame(query, planner) + assertPlanCache(sessionHolder, None, true) } finally { // Set back to default value. SparkEnv.get.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE, 5) @@ -404,14 +403,11 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession { val query = buildRelation("select 1") - // If cachePlan is false, the cache is still empty. - // Although the cache is created as cache size is greater than zero, it won't be used. - planner.transformRelation(query, cachePlan = false) - assertPlanCache(sessionHolder, Some(Set())) + // The cache must be empty. + sessionHolder.createDataFrame(query, planner) + assertPlanCache(sessionHolder, Some(Set()), true) - // Even if we specify "cachePlan = true", the cache is still empty. - planner.transformRelation(query, cachePlan = true) - assertPlanCache(sessionHolder, Some(Set())) + sessionHolder.session.conf.set(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED.key, true) } test("Test duplicate operation IDs") {