[SPARK-50513][SS][SQL] Split EncoderImplicits from SQLImplicits and provide helper object within StatefulProcessor to access underlying SQL Encoder related implicit functions

### What changes were proposed in this pull request?
Split `EncoderImplicits` out of `SQLImplicits` and provide a helper object within `StatefulProcessor` that gives access to the underlying SQL `Encoder`-related implicit functions.

### Why are the changes needed?
Without this change, we cannot use the implicit encoder APIs on the executors; we would run into an NPE, since the `SparkSession` is not available there. One option is to pass the `SparkSession` or `SQLImplicits` directly to the `StatefulProcessor`. However, this risks again exposing methods that rely on `SparkSession`.

```
Caused by: Task 11 in stage 4.0 failed 4 times, most recent failure: Lost task 11.3 in stage 4.0 (TID 57) (10.68.181.194 executor 2): java.lang.NullPointerException: Cannot invoke "org.apache.spark.sql.SparkSession.implicits()" because the return value of "$line94e51c7e8a51415d827a613fcad00cc19.$read$$iw$$iw.spark()" is null
	at $line94e51c7e8a51415d827a613fcad00cc152.$read$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$$iw$FruitCountStatefulProcessor.init(command-412580759966543:18)
	at org.apache.spark.sql.execution.streaming.TransformWithStateExec.processDataWithInitialState(TransformWithStateExec.scala:732)
	at org.apache.spark.sql.execution.streaming.TransformWithStateExec.$anonfun$doExecute$2(TransformWithStateExec.scala:625)
```
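
For illustration, here is a minimal REPL-style sketch of the failing pattern (the `CountProcessor` class and `countState` name are hypothetical, not from this PR). Importing `spark.implicits._` inside the processor ties encoder resolution to the driver-side session, which is `null` once the processor is deserialized on an executor:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{
  OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState}

// Driver-side session, e.g. the `spark` value in a notebook or shell.
val spark: SparkSession = SparkSession.builder().getOrCreate()

class CountProcessor extends StatefulProcessor[String, String, (String, Long)] {
  // Encoders resolve through the captured session; on an executor `spark`
  // deserializes as null, so init() throws the NPE shown in the trace above.
  import spark.implicits._

  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // Relies on the implicit Encoder[Long] from spark.implicits._
    countState = getHandle.getValueState[Long]("countState", TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, Long)] = {
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    countState.update(count)
    Iterator.single((key, count))
  }
}
```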

Hence, we do two things:
- split the encoder-related functions that don't rely on `SparkSession` into a separate trait, `EncoderImplicits`
- expose a helper `implicits` object within the `StatefulProcessor` interface that users can import directly, giving access to only the necessary functions (see the sketch below)
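
A sketch of the resulting user-facing pattern, reusing the same hypothetical `CountProcessor` from above; `import implicits._` now brings only the session-free encoder implicits from the new helper object into scope, so the same code runs safely on executors:

```scala
import org.apache.spark.sql.streaming.{
  OutputMode, StatefulProcessor, TimeMode, TimerValues, TTLConfig, ValueState}

class CountProcessor extends StatefulProcessor[String, String, (String, Long)] {
  // The `implicits` object on StatefulProcessor extends EncoderImplicits,
  // so only SparkSession-free encoder implicits come into scope.
  import implicits._

  @transient private var countState: ValueState[Long] = _

  override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
    // Encoder[Long] now comes from EncoderImplicits; no session lookup needed.
    countState = getHandle.getValueState[Long]("countState", TTLConfig.NONE)
  }

  override def handleInputRows(
      key: String,
      inputRows: Iterator[String],
      timerValues: TimerValues): Iterator[(String, Long)] = {
    val count = countState.getOption().getOrElse(0L) + inputRows.size
    countState.update(count)
    Iterator.single((key, count))
  }
}
```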

### Does this PR introduce _any_ user-facing change?
Yes

### How was this patch tested?
Added unit tests

```
[info] Run completed in 9 minutes, 30 seconds.
[info] Total number of tests run: 128
[info] Suites: completed 1, aborted 0
[info] Tests: succeeded 128, failed 0, canceled 0, ignored 0, pending 0
[info] All tests passed.
```

### Was this patch authored or co-authored using generative AI tooling?
No

Closes apache#49099 from anishshri-db/task/SPARK-50513.

Authored-by: Anish Shrigondekar <[email protected]>
Signed-off-by: Jungtaek Lim <[email protected]>
anishshri-db authored and HeartSaVioR committed Dec 10, 2024
1 parent 02ebf12 commit 03c5799
Showing 5 changed files with 357 additions and 119 deletions.
53 changes: 29 additions & 24 deletions sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, DE
*
* @since 1.6.0
*/
-abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable {
+abstract class SQLImplicits extends EncoderImplicits with Serializable {
type DS[U] <: Dataset[U]

protected def session: SparkSession
@@ -51,8 +51,35 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable {
}
}

// Primitives
/**
* Creates a [[Dataset]] from a local Seq.
* @since 1.6.0
*/
implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T, DS] = {
new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]])
}

/**
* Creates a [[Dataset]] from an RDD.
*
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] =
new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]])

/**
* An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]].
* @since 1.3.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
}

/**
* EncoderImplicits used to implicitly generate SQL Encoders. Note that these functions don't rely
* on or expose `SparkSession`.
*/
trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable {
// Primitives
/** @since 1.6.0 */
implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt

@@ -270,28 +297,6 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable {
/** @since 1.6.1 */
implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] =
newArrayEncoder(ScalaReflection.encoderFor[A])

/**
* Creates a [[Dataset]] from a local Seq.
* @since 1.6.0
*/
implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T, DS] = {
new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]])
}

/**
* Creates a [[Dataset]] from an RDD.
*
* @since 1.6.0
*/
implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] =
new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]])

/**
* An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]].
* @since 1.3.0
*/
implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
}

@@ -20,16 +20,25 @@ package org.apache.spark.sql.streaming
import java.io.Serializable

import org.apache.spark.annotation.{Evolving, Experimental}
import org.apache.spark.sql.api.EncoderImplicits
import org.apache.spark.sql.errors.ExecutionErrors

/**
* Represents the arbitrary stateful logic that needs to be provided by the user to perform
* stateful manipulations on keyed streams.
*
* Users can also explicitly use `import implicits._` to access the EncoderImplicits and use the
* state variable APIs relying on implicit encoders.
*/
@Experimental
@Evolving
private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable {

// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
object implicits extends EncoderImplicits
// scalastyle:on

/**
* Handle to the stateful processor that provides access to the state store and other stateful
* processing related APIs.
@@ -0,0 +1,211 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.streaming

import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession}
import org.apache.spark.sql.LocalSparkSession.withSparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider
import org.apache.spark.sql.internal.SQLConf

case class FruitState(
name: String,
count: Long,
family: String
)

class FruitCountStatefulProcessor(useImplicits: Boolean)
extends StatefulProcessor[String, String, (String, Long, String)] {
import implicits._

@transient protected var _fruitState: ValueState[FruitState] = _

override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
if (useImplicits) {
_fruitState = getHandle.getValueState[FruitState]("fruitState", TTLConfig.NONE)
} else {
_fruitState = getHandle.getValueState("fruitState", Encoders.product[FruitState],
TTLConfig.NONE)
}
}

private def getFamily(fruitName: String): String = {
if (fruitName == "orange" || fruitName == "lemon" || fruitName == "lime") {
"citrus"
} else {
"non-citrus"
}
}

override def handleInputRows(key: String, inputRows: Iterator[String], timerValues: TimerValues):
Iterator[(String, Long, String)] = {
val new_cnt = _fruitState.getOption().map(x => x.count).getOrElse(0L) + inputRows.size
val family = getFamily(key)
_fruitState.update(FruitState(key, new_cnt, family))
Iterator.single((key, new_cnt, family))
}
}

class FruitCountStatefulProcessorWithInitialState(useImplicits: Boolean)
extends StatefulProcessorWithInitialState[String, String, (String, Long, String), String] {
import implicits._

@transient protected var _fruitState: ValueState[FruitState] = _

override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = {
if (useImplicits) {
_fruitState = getHandle.getValueState[FruitState]("fruitState", TTLConfig.NONE)
} else {
_fruitState = getHandle.getValueState("fruitState", Encoders.product[FruitState],
TTLConfig.NONE)
}
}

private def getFamily(fruitName: String): String = {
if (fruitName == "orange" || fruitName == "lemon" || fruitName == "lime") {
"citrus"
} else {
"non-citrus"
}
}

override def handleInitialState(key: String, initialState: String,
timerValues: TimerValues): Unit = {
val new_cnt = _fruitState.getOption().map(x => x.count).getOrElse(0L) + 1
val family = getFamily(key)
_fruitState.update(FruitState(key, new_cnt, family))
}

override def handleInputRows(key: String, inputRows: Iterator[String], timerValues: TimerValues):
Iterator[(String, Long, String)] = {
val new_cnt = _fruitState.getOption().map(x => x.count).getOrElse(0L) + inputRows.size
val family = getFamily(key)
_fruitState.update(FruitState(key, new_cnt, family))
Iterator.single((key, new_cnt, family))
}
}

trait TransformWithStateClusterSuiteBase extends SparkFunSuite {
def getSparkConf(): SparkConf = {
val conf = new SparkConf()
.setMaster("local-cluster[2, 2, 1024]")
.set(SQLConf.STATE_STORE_PROVIDER_CLASS.key,
classOf[RocksDBStateStoreProvider].getCanonicalName)
.set(SQLConf.SHUFFLE_PARTITIONS.key,
TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString)
.set(SQLConf.STREAMING_STOP_TIMEOUT, 5000L)
conf
}

// Start a new test with cluster containing two executors and streaming stop timeout set to 5s
val testSparkConf = getSparkConf()

protected def testWithAndWithoutImplicitEncoders(name: String)
(func: (SparkSession, Boolean) => Any): Unit = {
Seq(false, true).foreach { useImplicits =>
test(s"$name - useImplicits = $useImplicits") {
withSparkSession(SparkSession.builder().config(testSparkConf).getOrCreate()) { spark =>
func(spark, useImplicits)
}
}
}
}
}

/**
* Test suite spawning local cluster with multiple executors to test serde of stateful
* processors along with use of implicit encoders, if applicable in transformWithState operator.
*/
class TransformWithStateClusterSuite extends StreamTest with TransformWithStateClusterSuiteBase {
testWithAndWithoutImplicitEncoders("streaming with transformWithState - " +
"without initial state") { (spark, useImplicits) =>
import spark.implicits._
val input = MemoryStream(Encoders.STRING, spark.sqlContext)
val agg = input.toDS()
.groupByKey(x => x)
.transformWithState(new FruitCountStatefulProcessor(useImplicits),
TimeMode.None(),
OutputMode.Update()
)

val query = agg.writeStream
.format("memory")
.outputMode("update")
.queryName("output")
.start()

input.addData("apple", "apple", "orange", "orange", "orange")
query.processAllAvailable()

checkAnswer(spark.sql("select * from output"),
Seq(Row("apple", 2, "non-citrus"),
Row("orange", 3, "citrus")))

input.addData("lemon", "lime")
query.processAllAvailable()
checkAnswer(spark.sql("select * from output"),
Seq(Row("apple", 2, "non-citrus"),
Row("orange", 3, "citrus"),
Row("lemon", 1, "citrus"),
Row("lime", 1, "citrus")))

query.stop()
}

testWithAndWithoutImplicitEncoders("streaming with transformWithState - " +
"with initial state") { (spark, useImplicits) =>
import spark.implicits._

val fruitCountInitialDS: Dataset[String] = Seq(
"apple", "apple", "orange", "orange", "orange").toDS()

val fruitCountInitial = fruitCountInitialDS
.groupByKey(x => x)

val input = MemoryStream(Encoders.STRING, spark.sqlContext)
val agg = input.toDS()
.groupByKey(x => x)
.transformWithState(new FruitCountStatefulProcessorWithInitialState(useImplicits),
TimeMode.None(),
OutputMode.Update(), fruitCountInitial)

val query = agg.writeStream
.format("memory")
.outputMode("update")
.queryName("output")
.start()

input.addData("apple", "apple", "orange", "orange", "orange")
query.processAllAvailable()

checkAnswer(spark.sql("select * from output"),
Seq(Row("apple", 4, "non-citrus"),
Row("orange", 6, "citrus")))

input.addData("lemon", "lime")
query.processAllAvailable()
checkAnswer(spark.sql("select * from output"),
Seq(Row("apple", 4, "non-citrus"),
Row("orange", 6, "citrus"),
Row("lemon", 1, "citrus"),
Row("lime", 1, "citrus")))

query.stop()
}
}
@@ -47,20 +47,18 @@ case class UnionUnflattenInitialStateRow(
abstract class StatefulProcessorWithInitialStateTestClass[V]
extends StatefulProcessorWithInitialState[
String, InitInputRow, (String, String, Double), V] {
import implicits._

@transient var _valState: ValueState[Double] = _
@transient var _listState: ListState[Double] = _
@transient var _mapState: MapState[Double, Int] = _

override def init(
outputMode: OutputMode,
timeMode: TimeMode): Unit = {
-    _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble,
-      TTLConfig.NONE)
-    _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble,
-      TTLConfig.NONE)
-    _mapState = getHandle.getMapState[Double, Int](
-      "testMapInit", Encoders.scalaDouble, Encoders.scalaInt,
-      TTLConfig.NONE)
+    _valState = getHandle.getValueState[Double]("testValueInit", TTLConfig.NONE)
+    _listState = getHandle.getListState[Double]("testListInit", TTLConfig.NONE)
+    _mapState = getHandle.getMapState[Double, Int]("testMapInit", TTLConfig.NONE)
}

override def handleInputRows(