diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
index 5b56b7079a897..3bde1846f6a08 100644
--- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
+++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/service/SessionHolder.scala
@@ -42,6 +42,7 @@ import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
 import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
 import org.apache.spark.sql.connect.service.ExecuteKey
 import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
+import org.apache.spark.sql.execution.QueryExecution
 import org.apache.spark.sql.streaming.StreamingQueryListener
 import org.apache.spark.util.{SystemClock, Utils}
 
@@ -450,14 +451,14 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
    */
   private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
       transform: proto.Relation => LogicalPlan): LogicalPlan = {
-    val planCacheEnabled = Option(session)
-      .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
     // We only cache plans that have a plan ID.
-    val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId
+    val planCacheEnabled = rel.hasCommon && rel.getCommon.hasPlanId &&
+      Option(session)
+        .forall(_.sessionState.conf.getConf(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
 
     def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
+        case Some(cache) if planCacheEnabled =>
           Option(cache.getIfPresent(rel)) match {
             case Some(plan) =>
               logDebug(s"Using cached plan for relation '$rel': $plan")
@@ -466,20 +467,36 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
           }
         case _ => None
       }
-    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
+    def putPlanCache(rel: proto.Relation, plan: LogicalPlan): LogicalPlan = {
       planCache match {
-        case Some(cache) if planCacheEnabled && hasPlanId =>
-          cache.put(rel, plan)
-        case _ =>
+        case Some(cache) if planCacheEnabled =>
+          val analyzedPlan = if (plan.analyzed) {
+            plan
+          } else {
+            val qe = new QueryExecution(session, plan)
+            if (qe.isLazyAnalysis) {
+              // The plan is intended to be lazily analyzed.
+              plan
+            } else {
+              // Make sure that the plan is fully analyzed before being cached.
+              qe.assertAnalyzed()
+              qe.analyzed
+            }
+          }
+          cache.put(rel, analyzedPlan)
+          analyzedPlan
+        case _ => plan
       }
+    }
 
     getPlanCache(rel)
       .getOrElse({
         val plan = transform(rel)
         if (cachePlan) {
           putPlanCache(rel, plan)
+        } else {
+          plan
         }
-        plan
       })
   }
 
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
index 21f84291a2f07..7d53c26ccb918 100644
--- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/service/SparkConnectSessionHolderSuite.scala
@@ -314,6 +314,8 @@ class SparkConnectSessionHolderSuite extends SharedSparkSession {
       case Some(expectedCachedRelations) =>
         val cachedRelations = sessionHolder.getPlanCache.get.asMap().keySet().asScala
         assert(cachedRelations.size == expectedCachedRelations.size)
+        val cachedLogicalPlans = sessionHolder.getPlanCache.get.asMap().values().asScala
+        cachedLogicalPlans.foreach(plan => assert(plan.analyzed))
         expectedCachedRelations.foreach(relation => assert(cachedRelations.contains(relation)))
       case None => assert(sessionHolder.getPlanCache.isEmpty)
     }
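
The substance of the patch is the analyze-before-cache step in `putPlanCache`: the cache now holds fully analyzed plans, so a cache hit returns a plan that never needs to be re-analyzed, and analysis errors surface when the plan is first cached rather than on a later lookup. Below is a minimal sketch of that step in isolation; `toAnalyzed` is a hypothetical helper name, but every call it makes (`plan.analyzed`, `new QueryExecution(...)`, `isLazyAnalysis`, `assertAnalyzed()`, `analyzed`) is taken verbatim from the patch above:

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution

// Hypothetical helper mirroring the analyze-before-cache step in putPlanCache.
def toAnalyzed(session: SparkSession, plan: LogicalPlan): LogicalPlan = {
  if (plan.analyzed) {
    // Already analyzed (e.g. a plan previously served from the cache); reuse as-is.
    plan
  } else {
    val qe = new QueryExecution(session, plan)
    if (qe.isLazyAnalysis) {
      // Per the patch, plans intended for lazy analysis are cached unanalyzed.
      plan
    } else {
      // Fail fast: raise any analysis error now instead of caching a broken plan.
      qe.assertAnalyzed()
      // The resolved plan, safe to reuse across executions of the same relation.
      qe.analyzed
    }
  }
}

Changing `putPlanCache` to return `LogicalPlan` instead of `Unit` lets `usePlanCache` hand the caller the exact plan instance that was cached, which is why the `getOrElse` branch now returns `putPlanCache(rel, plan)` directly and only falls back to the raw `plan` when `cachePlan` is false.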