[SPARK-51085][SQL] Restore SQLContext Companion #49964

Closed · wants to merge 1 commit
13 changes: 11 additions & 2 deletions sql/api/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -980,7 +980,6 @@ abstract class SQLContext private[sql] (val sparkSession: SparkSession)
  */
 private[sql] trait SQLContextCompanion {
   private[sql] type SQLContextImpl <: SQLContext
-  private[sql] type SparkContextImpl <: SparkContext

   /**
    * Get the singleton SQLContext if it exists or create a new one using the given SparkContext.
@@ -994,7 +993,7 @@ private[sql] trait SQLContextCompanion {
    * @since 1.5.0
    */
   @deprecated("Use SparkSession.builder instead", "2.0.0")
-  def getOrCreate(sparkContext: SparkContextImpl): SQLContextImpl
+  def getOrCreate(sparkContext: SparkContext): SQLContextImpl

   /**
    * Changes the SQLContext that will be returned in this thread and its children when
@@ -1019,3 +1018,13 @@ private[sql] trait SQLContextCompanion {
     SparkSession.clearActiveSession()
   }
 }
+
+object SQLContext extends SQLContextCompanion {
+  private[sql] type SQLContextImpl = SQLContext
+
+  /** @inheritdoc */
+  @deprecated("Use SparkSession.builder instead", "2.0.0")
+  def getOrCreate(sparkContext: SparkContext): SQLContext = {
+    SparkSession.builder().sparkContext(sparkContext).getOrCreate().sqlContext
+  }
+}
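For context: restoring the companion re-enables the pre-2.0 singleton entry point in sql/api. A minimal sketch of the legacy call path this brings back to compiling, assuming classic mode with a local master (the master URL, app name, and query are illustrative, not part of the patch):

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object LegacyApp {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("legacy-app"))
    // Deprecated since 2.0.0, but old applications compile against it again;
    // per the diff above it routes through SparkSession.builder().sparkContext(sc).
    val sqlContext = SQLContext.getOrCreate(sc)
    sqlContext.range(10).show()
    sc.stop()
  }
}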
@@ -17,22 +17,57 @@
 package org.apache.spark.sql

 // scalastyle:off funsuite
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach}
 import org.scalatest.funsuite.AnyFunSuite

-import org.apache.spark.sql.functions.sum
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.functions.{max, sum}

 /**
  * Test suite for SparkSession implementation binding.
  */
-trait SparkSessionBuilderImplementationBindingSuite extends AnyFunSuite with BeforeAndAfterAll {
+trait SparkSessionBuilderImplementationBindingSuite
+    extends AnyFunSuite
+    with BeforeAndAfterAll
+    with BeforeAndAfterEach {
   // scalastyle:on
   protected def configure(builder: SparkSessionBuilder): builder.type = builder

+  protected def sparkContext: SparkContext
+  protected def implementationPackageName: String = getClass.getPackageName
+
+  private def assertInCorrectPackage[T](obj: T): Unit = {
+    assert(obj.getClass.getPackageName == implementationPackageName)
+  }
+
+  override protected def beforeEach(): Unit = {
+    super.beforeEach()
+    clearSessions()
+  }
+
+  override protected def afterAll(): Unit = {
+    clearSessions()
+    super.afterAll()
+  }
+
+  private def clearSessions(): Unit = {
+    SparkSession.clearActiveSession()
+    SparkSession.clearDefaultSession()
+  }
+
   test("range") {
-    val session = configure(SparkSession.builder()).getOrCreate()
+    val session = SparkSession.builder().getOrCreate()
+    assertInCorrectPackage(session)
     import session.implicits._
     val df = session.range(10).agg(sum("id")).as[Long]
     assert(df.head() == 45)
   }
+
+  test("sqlContext") {
+    SparkSession.clearActiveSession()
+    val ctx = SQLContext.getOrCreate(sparkContext)
+    assertInCorrectPackage(ctx)
+    import ctx.implicits._
+    val df = ctx.createDataset(1 to 11).select(max("value").as[Long])
+    assert(df.head() == 11)
+  }
 }
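Each implementation module binds this shared trait by mixing it into a concrete suite and supplying the sparkContext hook that the new sqlContext test hands to SQLContext.getOrCreate. A hypothetical classic-mode binding, for illustration only (the package name and the SparkContext wiring are assumptions, not taken from this PR):

package org.apache.spark.sql.classic  // hypothetical package, for illustration

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql

class SparkSessionBuilderImplementationBindingSuite
    extends sql.SparkSessionBuilderImplementationBindingSuite {
  // Classic mode has a real JVM-side SparkContext to pass to SQLContext.getOrCreate;
  // assertInCorrectPackage then checks the returned SQLContext lives in this package.
  override protected lazy val sparkContext: SparkContext =
    new SparkContext(new SparkConf().setMaster("local[2]").setAppName("binding-suite"))

  override protected def afterAll(): Unit = {
    try super.afterAll()
    finally sparkContext.stop()
  }
}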
@@ -16,8 +16,7 @@
  */
 package org.apache.spark.sql.connect

-import org.apache.spark.sql
-import org.apache.spark.sql.SparkSessionBuilder
+import org.apache.spark.{sql, SparkContext}
 import org.apache.spark.sql.connect.test.{ConnectFunSuite, RemoteSparkSession}

 /**
@@ -27,8 +26,11 @@ class SparkSessionBuilderImplementationBindingSuite
   extends ConnectFunSuite
   with sql.SparkSessionBuilderImplementationBindingSuite
   with RemoteSparkSession {
-  override protected def configure(builder: SparkSessionBuilder): builder.type = {
+  override def beforeAll(): Unit = {
     // We need to set this configuration because the port used by the server is random.
-    builder.remote(s"sc://localhost:$serverPort")
+    System.setProperty("spark.remote", s"sc://localhost:$serverPort")
+    super.beforeAll()
   }
+
+  override protected def sparkContext: SparkContext = null
 }
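Moving the server address from the configure hook to the spark.remote system property matters because the new sqlContext test creates its session inside SQLContext.getOrCreate, which calls SparkSession.builder() directly and never passes through configure; a process-wide property reaches every builder. A rough sketch of the equivalence (the port is assumed here; the test itself uses the random $serverPort):

import org.apache.spark.sql.SparkSession

// Both builders end up talking to the same Connect server.
System.setProperty("spark.remote", "sc://localhost:15002")
val viaProperty = SparkSession.builder().getOrCreate()
val viaBuilder = SparkSession.builder().remote("sc://localhost:15002").getOrCreate()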
@@ -305,29 +305,13 @@ class SQLContext private[sql] (override val sparkSession: SparkSession)
     super.jdbc(url, table, theParts)
   }
 }

 object SQLContext extends sql.SQLContextCompanion {

   override private[sql] type SQLContextImpl = SQLContext
-  override private[sql] type SparkContextImpl = SparkContext

-  /**
-   * Get the singleton SQLContext if it exists or create a new one.
-   *
-   * This function can be used to create a singleton SQLContext object that can be shared across
-   * the JVM.
-   *
-   * If there is an active SQLContext for current thread, it will be returned instead of the
-   * global one.
-   *
-   * @param sparkContext
-   *   The SparkContext. This parameter is not used in Spark Connect.
-   *
-   * @since 4.0.0
-   */
+  /** @inheritdoc */
   def getOrCreate(sparkContext: SparkContext): SQLContext = {
     SparkSession.builder().getOrCreate().sqlContext
   }

   /** @inheritdoc */
   override def setActive(sqlContext: SQLContext): Unit = super.setActive(sqlContext)
 }
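Note that the Connect override drops sparkContext from the builder chain entirely, which is why the suite above can return null for it: the session is resolved from the spark.remote address instead, as the removed scaladoc stated. A sketch of what a legacy caller sees, assuming spark.remote or SPARK_REMOTE already points at a running server:

import org.apache.spark.sql.connect.SQLContext

// The argument is unused in Connect mode, so callers with no JVM-side
// SparkContext can still obtain a working SQLContext.
val ctx = SQLContext.getOrCreate(null)
ctx.sparkSession.range(3).show()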
@@ -378,18 +378,13 @@ class SQLContext private[sql] (override val sparkSession: SparkSession)
 }

 object SQLContext extends sql.SQLContextCompanion {

   override private[sql] type SQLContextImpl = SQLContext
-  override private[sql] type SparkContextImpl = SparkContext

   /** @inheritdoc */
   def getOrCreate(sparkContext: SparkContext): SQLContext = {
     newSparkSessionBuilder().sparkContext(sparkContext).getOrCreate().sqlContext
   }

   /** @inheritdoc */
   override def setActive(sqlContext: SQLContext): Unit = super.setActive(sqlContext)

   /**
    * Converts an iterator of Java Beans to InternalRow using the provided bean info & schema. This
    * is not related to the singleton, but is a static method for internal use.