diff --git a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala b/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala index 5e022570d3ca7..200e913b5412e 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/api/SQLImplicits.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{ArrayEncoder, DE * * @since 1.6.0 */ -abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable { +abstract class SQLImplicits extends EncoderImplicits with Serializable { type DS[U] <: Dataset[U] protected def session: SparkSession @@ -51,8 +51,35 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable { } } - // Primitives + /** + * Creates a [[Dataset]] from a local Seq. + * @since 1.6.0 + */ + implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T, DS] = { + new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]]) + } + + /** + * Creates a [[Dataset]] from an RDD. + * + * @since 1.6.0 + */ + implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] = + new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]]) + + /** + * An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]]. + * @since 1.3.0 + */ + implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) +} +/** + * EncoderImplicits used to implicitly generate SQL Encoders. Note that these functions don't rely + * on or expose `SparkSession`. + */ +trait EncoderImplicits extends LowPrioritySQLImplicits with Serializable { + // Primitives /** @since 1.6.0 */ implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt @@ -270,28 +297,6 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits with Serializable { /** @since 1.6.1 */ implicit def newProductArrayEncoder[A <: Product: TypeTag]: Encoder[Array[A]] = newArrayEncoder(ScalaReflection.encoderFor[A]) - - /** - * Creates a [[Dataset]] from a local Seq. - * @since 1.6.0 - */ - implicit def localSeqToDatasetHolder[T: Encoder](s: Seq[T]): DatasetHolder[T, DS] = { - new DatasetHolder(session.createDataset(s).asInstanceOf[DS[T]]) - } - - /** - * Creates a [[Dataset]] from an RDD. - * - * @since 1.6.0 - */ - implicit def rddToDatasetHolder[T: Encoder](rdd: RDD[T]): DatasetHolder[T, DS] = - new DatasetHolder(session.createDataset(rdd).asInstanceOf[DS[T]]) - - /** - * An implicit conversion that turns a Scala `Symbol` into a [[org.apache.spark.sql.Column]]. - * @since 1.3.0 - */ - implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name) } /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala index 55477b4dda0c9..b47629cb54396 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/streaming/StatefulProcessor.scala @@ -20,16 +20,25 @@ package org.apache.spark.sql.streaming import java.io.Serializable import org.apache.spark.annotation.{Evolving, Experimental} +import org.apache.spark.sql.api.EncoderImplicits import org.apache.spark.sql.errors.ExecutionErrors /** * Represents the arbitrary stateful logic that needs to be provided by the user to perform * stateful manipulations on keyed streams. 
+ * + * Users can also explicitly use `import implicits._` to access the EncoderImplicits and use the + * state variable APIs relying on implicit encoders. */ @Experimental @Evolving private[sql] abstract class StatefulProcessor[K, I, O] extends Serializable { + // scalastyle:off + // Disable style checker so "implicits" object can start with lowercase i + object implicits extends EncoderImplicits + // scalastyle:on + /** * Handle to the stateful processor that provides access to the state store and other stateful * processing related APIs. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala new file mode 100644 index 0000000000000..3e2899f7c6ee7 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala @@ -0,0 +1,211 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.{SparkConf, SparkFunSuite} +import org.apache.spark.sql.{Dataset, Encoders, Row, SparkSession} +import org.apache.spark.sql.LocalSparkSession.withSparkSession +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider +import org.apache.spark.sql.internal.SQLConf + +case class FruitState( + name: String, + count: Long, + family: String +) + +class FruitCountStatefulProcessor(useImplicits: Boolean) + extends StatefulProcessor[String, String, (String, Long, String)] { + import implicits._ + + @transient protected var _fruitState: ValueState[FruitState] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + if (useImplicits) { + _fruitState = getHandle.getValueState[FruitState]("fruitState", TTLConfig.NONE) + } else { + _fruitState = getHandle.getValueState("fruitState", Encoders.product[FruitState], + TTLConfig.NONE) + } + } + + private def getFamily(fruitName: String): String = { + if (fruitName == "orange" || fruitName == "lemon" || fruitName == "lime") { + "citrus" + } else { + "non-citrus" + } + } + + override def handleInputRows(key: String, inputRows: Iterator[String], timerValues: TimerValues): + Iterator[(String, Long, String)] = { + val new_cnt = _fruitState.getOption().map(x => x.count).getOrElse(0L) + inputRows.size + val family = getFamily(key) + _fruitState.update(FruitState(key, new_cnt, family)) + Iterator.single((key, new_cnt, family)) + } +} + +class FruitCountStatefulProcessorWithInitialState(useImplicits: Boolean) + extends StatefulProcessorWithInitialState[String, String, (String, Long, String), String] { + import implicits._ + + @transient protected var _fruitState: 
ValueState[FruitState] = _ + + override def init(outputMode: OutputMode, timeMode: TimeMode): Unit = { + if (useImplicits) { + _fruitState = getHandle.getValueState[FruitState]("fruitState", TTLConfig.NONE) + } else { + _fruitState = getHandle.getValueState("fruitState", Encoders.product[FruitState], + TTLConfig.NONE) + } + } + + private def getFamily(fruitName: String): String = { + if (fruitName == "orange" || fruitName == "lemon" || fruitName == "lime") { + "citrus" + } else { + "non-citrus" + } + } + + override def handleInitialState(key: String, initialState: String, + timerValues: TimerValues): Unit = { + val new_cnt = _fruitState.getOption().map(x => x.count).getOrElse(0L) + 1 + val family = getFamily(key) + _fruitState.update(FruitState(key, new_cnt, family)) + } + + override def handleInputRows(key: String, inputRows: Iterator[String], timerValues: TimerValues): + Iterator[(String, Long, String)] = { + val new_cnt = _fruitState.getOption().map(x => x.count).getOrElse(0L) + inputRows.size + val family = getFamily(key) + _fruitState.update(FruitState(key, new_cnt, family)) + Iterator.single((key, new_cnt, family)) + } +} + +trait TransformWithStateClusterSuiteBase extends SparkFunSuite { + def getSparkConf(): SparkConf = { + val conf = new SparkConf() + .setMaster("local-cluster[2, 2, 1024]") + .set(SQLConf.STATE_STORE_PROVIDER_CLASS.key, + classOf[RocksDBStateStoreProvider].getCanonicalName) + .set(SQLConf.SHUFFLE_PARTITIONS.key, + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) + .set(SQLConf.STREAMING_STOP_TIMEOUT, 5000L) + conf + } + + // Start a new test with cluster containing two executors and streaming stop timeout set to 5s + val testSparkConf = getSparkConf() + + protected def testWithAndWithoutImplicitEncoders(name: String) + (func: (SparkSession, Boolean) => Any): Unit = { + Seq(false, true).foreach { useImplicits => + test(s"$name - useImplicits = $useImplicits") { + withSparkSession(SparkSession.builder().config(testSparkConf).getOrCreate()) { spark => + func(spark, useImplicits) + } + } + } + } +} + +/** + * Test suite spawning local cluster with multiple executors to test serde of stateful + * processors along with use of implicit encoders, if applicable in transformWithState operator. 
+ */ +class TransformWithStateClusterSuite extends StreamTest with TransformWithStateClusterSuiteBase { + testWithAndWithoutImplicitEncoders("streaming with transformWithState - " + + "without initial state") { (spark, useImplicits) => + import spark.implicits._ + val input = MemoryStream(Encoders.STRING, spark.sqlContext) + val agg = input.toDS() + .groupByKey(x => x) + .transformWithState(new FruitCountStatefulProcessor(useImplicits), + TimeMode.None(), + OutputMode.Update() + ) + + val query = agg.writeStream + .format("memory") + .outputMode("update") + .queryName("output") + .start() + + input.addData("apple", "apple", "orange", "orange", "orange") + query.processAllAvailable() + + checkAnswer(spark.sql("select * from output"), + Seq(Row("apple", 2, "non-citrus"), + Row("orange", 3, "citrus"))) + + input.addData("lemon", "lime") + query.processAllAvailable() + checkAnswer(spark.sql("select * from output"), + Seq(Row("apple", 2, "non-citrus"), + Row("orange", 3, "citrus"), + Row("lemon", 1, "citrus"), + Row("lime", 1, "citrus"))) + + query.stop() + } + + testWithAndWithoutImplicitEncoders("streaming with transformWithState - " + + "with initial state") { (spark, useImplicits) => + import spark.implicits._ + + val fruitCountInitialDS: Dataset[String] = Seq( + "apple", "apple", "orange", "orange", "orange").toDS() + + val fruitCountInitial = fruitCountInitialDS + .groupByKey(x => x) + + val input = MemoryStream(Encoders.STRING, spark.sqlContext) + val agg = input.toDS() + .groupByKey(x => x) + .transformWithState(new FruitCountStatefulProcessorWithInitialState(useImplicits), + TimeMode.None(), + OutputMode.Update(), fruitCountInitial) + + val query = agg.writeStream + .format("memory") + .outputMode("update") + .queryName("output") + .start() + + input.addData("apple", "apple", "orange", "orange", "orange") + query.processAllAvailable() + + checkAnswer(spark.sql("select * from output"), + Seq(Row("apple", 4, "non-citrus"), + Row("orange", 6, "citrus"))) + + input.addData("lemon", "lime") + query.processAllAvailable() + checkAnswer(spark.sql("select * from output"), + Seq(Row("apple", 4, "non-citrus"), + Row("orange", 6, "citrus"), + Row("lemon", 1, "citrus"), + Row("lime", 1, "citrus"))) + + query.stop() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala index 29f40df83f24a..c7ad8536ebd46 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateInitialStateSuite.scala @@ -47,6 +47,8 @@ case class UnionUnflattenInitialStateRow( abstract class StatefulProcessorWithInitialStateTestClass[V] extends StatefulProcessorWithInitialState[ String, InitInputRow, (String, String, Double), V] { + import implicits._ + @transient var _valState: ValueState[Double] = _ @transient var _listState: ListState[Double] = _ @transient var _mapState: MapState[Double, Int] = _ @@ -54,13 +56,9 @@ abstract class StatefulProcessorWithInitialStateTestClass[V] override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { - _valState = getHandle.getValueState[Double]("testValueInit", Encoders.scalaDouble, - TTLConfig.NONE) - _listState = getHandle.getListState[Double]("testListInit", Encoders.scalaDouble, - TTLConfig.NONE) - _mapState = getHandle.getMapState[Double, Int]( - "testMapInit", Encoders.scalaDouble, 
Encoders.scalaInt, - TTLConfig.NONE) + _valState = getHandle.getValueState[Double]("testValueInit", TTLConfig.NONE) + _listState = getHandle.getListState[Double]("testListInit", TTLConfig.NONE) + _mapState = getHandle.getMapState[Double, Int]("testMapInit", TTLConfig.NONE) } override def handleInputRows( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 91a47645f4179..d4c5a735ce6fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -45,13 +45,13 @@ object TransformWithStateSuiteUtils { class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { + import implicits._ @transient protected var _countState: ValueState[Long] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { - _countState = getHandle.getValueState[Long]("countState", - Encoders.scalaLong, TTLConfig.NONE) + _countState = getHandle.getValueState[Long]("countState", TTLConfig.NONE) } override def handleInputRows( @@ -72,12 +72,13 @@ class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (S class RunningCountStatefulProcessorWithTTL extends StatefulProcessor[String, String, (String, String)] with Logging { + import implicits._ @transient protected var _countState: ValueState[Long] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { - _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, + _countState = getHandle.getValueState[Long]("countState", TTLConfig(Duration.ofMillis(1000))) } @@ -384,20 +385,32 @@ class RunningCountStatefulProcessorWithError extends RunningCountStatefulProcess } // class for verify state schema is correctly written for all state var types -class StatefulProcessorWithCompositeTypes extends RunningCountStatefulProcessor { +class StatefulProcessorWithCompositeTypes(useImplicits: Boolean) + extends RunningCountStatefulProcessor { + import implicits._ @transient private var _listState: ListState[TestClass] = _ @transient private var _mapState: MapState[POJOTestClass, String] = _ override def init( outputMode: OutputMode, timeMode: TimeMode): Unit = { - _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, - TTLConfig.NONE) - _listState = getHandle.getListState[TestClass]( - "listState", Encoders.product[TestClass], TTLConfig.NONE) - _mapState = getHandle.getMapState[POJOTestClass, String]( - "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING, - TTLConfig.NONE) + + if (useImplicits) { + _countState = getHandle.getValueState[Long]("countState", TTLConfig.NONE) + _listState = getHandle.getListState[TestClass]( + "listState", TTLConfig.NONE) + _mapState = getHandle.getMapState[POJOTestClass, String]( + "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING, + TTLConfig.NONE) + } else { + _countState = getHandle.getValueState[Long]("countState", Encoders.scalaLong, + TTLConfig.NONE) + _listState = getHandle.getListState[TestClass]( + "listState", Encoders.product[TestClass], TTLConfig.NONE) + _mapState = getHandle.getMapState[POJOTestClass, String]( + "mapState", Encoders.bean(classOf[POJOTestClass]), Encoders.STRING, + TTLConfig.NONE) + } } } @@ -1037,85 +1050,87 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } - 
test("transformWithState - verify StateSchemaV3 writes " + - "correct SQL schema of key/value") { - withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> - classOf[RocksDBStateStoreProvider].getName, - SQLConf.SHUFFLE_PARTITIONS.key -> - TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { - withTempDir { checkpointDir => - val metadataPathPostfix = "state/0/_stateSchema/default" - val stateSchemaPath = new Path(checkpointDir.toString, - s"$metadataPathPostfix") - val hadoopConf = spark.sessionState.newHadoopConf() - val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) - - val keySchema = new StructType().add("value", StringType) - val schema0 = StateStoreColFamilySchema( - "countState", - keySchema, - new StructType().add("value", LongType, false), - Some(NoPrefixKeyStateEncoderSpec(keySchema)), - None - ) - val schema1 = StateStoreColFamilySchema( - "listState", - keySchema, - new StructType() + Seq(false, true).foreach { useImplicits => + test("transformWithState - verify StateSchemaV3 writes " + + s"correct SQL schema of key/value with useImplicits=$useImplicits") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName, + SQLConf.SHUFFLE_PARTITIONS.key -> + TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS.toString) { + withTempDir { checkpointDir => + val metadataPathPostfix = "state/0/_stateSchema/default" + val stateSchemaPath = new Path(checkpointDir.toString, + s"$metadataPathPostfix") + val hadoopConf = spark.sessionState.newHadoopConf() + val fm = CheckpointFileManager.create(stateSchemaPath, hadoopConf) + + val keySchema = new StructType().add("value", StringType) + val schema0 = StateStoreColFamilySchema( + "countState", + keySchema, + new StructType().add("value", LongType, false), + Some(NoPrefixKeyStateEncoderSpec(keySchema)), + None + ) + val schema1 = StateStoreColFamilySchema( + "listState", + keySchema, + new StructType() .add("id", LongType, false) .add("name", StringType), - Some(NoPrefixKeyStateEncoderSpec(keySchema)), - None - ) - - val userKeySchema = new StructType() - .add("id", IntegerType, false) - .add("name", StringType) - val compositeKeySchema = new StructType() - .add("key", new StructType().add("value", StringType)) - .add("userKey", userKeySchema) - val schema2 = StateStoreColFamilySchema( - "mapState", - compositeKeySchema, - new StructType().add("value", StringType), - Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), - Option(userKeySchema) - ) - - val inputData = MemoryStream[String] - val result = inputData.toDS() - .groupByKey(x => x) - .transformWithState(new StatefulProcessorWithCompositeTypes(), - TimeMode.None(), - OutputMode.Update()) - - testStream(result, OutputMode.Update())( - StartStream(checkpointLocation = checkpointDir.getCanonicalPath), - AddData(inputData, "a", "b"), - CheckNewAnswer(("a", "1"), ("b", "1")), - Execute { q => - q.lastProgress.runId - val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath - val providerId = StateStoreProviderId(StateStoreId( - checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId) - val checker = new StateSchemaCompatibilityChecker(providerId, - hadoopConf, Some(schemaFilePath)) - val colFamilySeq = checker.readSchemaFile() - - assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == - q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt) - assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == - 
q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt) - assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == - q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt) - - assert(colFamilySeq.length == 3) - assert(colFamilySeq.map(_.toString).toSet == Set( - schema0, schema1, schema2 - ).map(_.toString)) - }, - StopStream - ) + Some(NoPrefixKeyStateEncoderSpec(keySchema)), + None + ) + + val userKeySchema = new StructType() + .add("id", IntegerType, false) + .add("name", StringType) + val compositeKeySchema = new StructType() + .add("key", new StructType().add("value", StringType)) + .add("userKey", userKeySchema) + val schema2 = StateStoreColFamilySchema( + "mapState", + compositeKeySchema, + new StructType().add("value", StringType), + Some(PrefixKeyScanStateEncoderSpec(compositeKeySchema, 1)), + Option(userKeySchema) + ) + + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new StatefulProcessorWithCompositeTypes(useImplicits), + TimeMode.None(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream(checkpointLocation = checkpointDir.getCanonicalPath), + AddData(inputData, "a", "b"), + CheckNewAnswer(("a", "1"), ("b", "1")), + Execute { q => + q.lastProgress.runId + val schemaFilePath = fm.list(stateSchemaPath).toSeq.head.getPath + val providerId = StateStoreProviderId(StateStoreId( + checkpointDir.getCanonicalPath, 0, 0), q.lastProgress.runId) + val checker = new StateSchemaCompatibilityChecker(providerId, + hadoopConf, Some(schemaFilePath)) + val colFamilySeq = checker.readSchemaFile() + + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numValueStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numListStateVars").toInt) + assert(TransformWithStateSuiteUtils.NUM_SHUFFLE_PARTITIONS == + q.lastProgress.stateOperators.head.customMetrics.get("numMapStateVars").toInt) + + assert(colFamilySeq.length == 3) + assert(colFamilySeq.map(_.toString).toSet == Set( + schema0, schema1, schema2 + ).map(_.toString)) + }, + StopStream + ) + } } } }