From 575cfba50ee693678bc51033bbd6d9cb618c1b2a Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 22 May 2024 11:21:06 -0700 Subject: [PATCH 01/11] mapstateimpl --- .../sql/execution/streaming/MapStateImpl.scala | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index c58f32ed756db..b6245f1750cb9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder +import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState import org.apache.spark.sql.types.{BinaryType, StructType} @@ -30,18 +31,12 @@ class MapStateImpl[K, V]( userKeyEnc: Encoder[K], valEncoder: Encoder[V]) extends MapState[K, V] with Logging { - // Pack grouping key and user key together as a prefixed composite key - private val schemaForCompositeKeyRow: StructType = - new StructType() - .add("key", BinaryType) - .add("userKey", BinaryType) - private val schemaForValueRow: StructType = new StructType().add("value", BinaryType) private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = new CompositeKeyStateEncoder( - keySerializer, userKeyEnc, valEncoder, schemaForCompositeKeyRow, stateName) + keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - store.createColFamilyIfAbsent(stateName, schemaForCompositeKeyRow, schemaForValueRow, - PrefixKeyScanStateEncoderSpec(schemaForCompositeKeyRow, 1)) + store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1)) /** Whether state exists or not. 
*/ override def exists(): Boolean = { From b9fd42f1b1ee5e5782f8352b5c5ffc9ad7d89d66 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Wed, 22 May 2024 13:48:06 -0700 Subject: [PATCH 02/11] adding support for StateSchemaV3 --- .../execution/streaming/ListStateImpl.scala | 7 +- .../streaming/ListStateImplWithTTL.scala | 8 ++- .../execution/streaming/MapStateImpl.scala | 9 ++- .../streaming/MapStateImplWithTTL.scala | 8 ++- .../StatefulProcessorHandleImpl.scala | 9 +++ .../execution/streaming/ValueStateImpl.scala | 7 +- .../streaming/ValueStateImplWithTTL.scala | 8 ++- .../state/HDFSBackedStateStoreProvider.scala | 4 ++ .../state/RocksDBStateStoreProvider.scala | 9 +++ .../streaming/state/SchemaHelper.scala | 64 +++++++++++++++++++ .../streaming/state/StateStore.scala | 29 ++++++++- 11 files changed, 142 insertions(+), 20 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 56c9d2664d9e2..fb8a18db860cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState /** @@ -44,8 +44,9 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) + store.createColFamilyIfAbsent(columnFamilyMetadataV1) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index dc72f8bcd5600..350b70b28f7aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} import org.apache.spark.util.NextIterator @@ -52,11 +52,13 @@ class ListStateImplWithTTL[S]( private lazy val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), useMultipleValuesPerKey = true) + store.createColFamilyIfAbsent(columnFamilyMetadataV1) } /** Whether state exists or not. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index b6245f1750cb9..99dfc763e9fd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState import org.apache.spark.sql.types.{BinaryType, StructType} @@ -35,8 +35,11 @@ class MapStateImpl[K, V]( private val stateTypesEncoder = new CompositeKeyStateEncoder( keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1)) + val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) + + store.createColFamilyIfAbsent(columnFamilyMetadataV1) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index 2ab06f36dd5f7..fe82e37035743 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} import org.apache.spark.util.NextIterator @@ -55,11 +55,13 @@ class MapStateImplWithTTL[K, V]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1)) + store.createColFamilyIfAbsent(columnFamilyMetadataV1) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index f8948b07457e5..1b32aa066d69b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -97,6 +97,9 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() + var columnFamilyMetadatas: mutable.ListBuffer[ColumnFamilyMetadataV1] = + mutable.ListBuffer.empty[ColumnFamilyMetadataV1] + private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" private def buildQueryInfo(): QueryInfo = { @@ -135,6 +138,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) incrementMetric("numValueStateVars") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + columnFamilyMetadatas.addOne(resultState.columnFamilyMetadataV1) resultState } @@ -151,6 +155,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numValueStateWithTTLVars") ttlStates.add(valueStateWithTTL) + columnFamilyMetadatas.addOne(valueStateWithTTL.columnFamilyMetadataV1) valueStateWithTTL } @@ -248,6 +253,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, ListState, false)) incrementMetric("numListStateVars") val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + columnFamilyMetadatas.addOne(resultState.columnFamilyMetadataV1) resultState } @@ -280,6 +286,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numListStateWithTTLVars") ttlStates.add(listStateWithTTL) + columnFamilyMetadatas.addOne(listStateWithTTL.columnFamilyMetadataV1) listStateWithTTL } @@ -292,6 +299,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, MapState, false)) incrementMetric("numMapStateVars") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + columnFamilyMetadatas.addOne(resultState.columnFamilyMetadataV1) resultState } @@ -309,6 +317,7 @@ class StatefulProcessorHandleImpl( valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) + columnFamilyMetadatas.addOne(mapStateWithTTL.columnFamilyMetadataV1) mapStateWithTTL } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index d916011245c00..7a02a340b3cd9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, 
NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState /** @@ -42,11 +42,12 @@ class ValueStateImpl[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) + val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + store.createColFamilyIfAbsent(columnFamilyMetadataV1) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index 0ed5a6f29a984..cc33e7fae694c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} /** @@ -49,11 +49,13 @@ class ValueStateImplWithTTL[S]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) + val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, + NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, - NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA)) + store.createColFamilyIfAbsent(columnFamilyMetadataV1) } /** Function to check if state exists. 
Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 543cd74c489d0..9a298ae6d5aab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -245,6 +245,10 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with colFamilyName: String): Unit = { throw StateStoreErrors.unsupportedOperationException("merge", providerName) } + + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = { + throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) + } } def getMetricsForProvider(): Map[String, Long] = synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index e7fc9f56dd9eb..82bf20365e52c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -264,6 +264,15 @@ private[sql] class RocksDBStateStoreProvider keyValueEncoderMap.remove(colFamilyName) result } + + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = { + createColFamilyIfAbsent( + colFamilyMetadata.columnFamilyName, + colFamilyMetadata.keySchema, + colFamilyMetadata.valueSchema, + colFamilyMetadata.keyStateEncoderSpec, + colFamilyMetadata.multipleValuesPerKey) + } } override def init( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 2eef3d9fc22ed..ad30f186e636a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -20,11 +20,34 @@ package org.apache.spark.sql.execution.streaming.state import java.io.StringReader import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} +import org.json4s.JsonAST +import org.json4s.JsonAST._ +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.sql.execution.streaming.MetadataVersionUtil import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils +class ColumnFamilyMetadataV1( + val columnFamilyName: String, + val keySchema: StructType, + val valueSchema: StructType, + val keyStateEncoderSpec: KeyStateEncoderSpec, + val multipleValuesPerKey: Boolean) { + def jsonValue: JsonAST.JObject = { + ("columnFamilyName" -> JString(columnFamilyName)) ~ + ("keySchema" -> keySchema.jsonValue) ~ + ("valueSchema" -> valueSchema.jsonValue) ~ + ("keyStateEncoderSpec" -> keyStateEncoderSpec.jsonValue) ~ + ("multipleValuesPerKey" -> JBool(multipleValuesPerKey)) + } + + def json: String = { + compact(render(jsonValue)) + } +} + /** * Helper classes for reading/writing state schema. 
*/ @@ -68,6 +91,16 @@ object SchemaHelper { } } + class SchemaV3Reader { + def read(inputStream: FSDataInputStream): ColumnFamilyMetadataV1 = { + val buf = new StringBuilder + val numMetadataChunks = inputStream.readInt() + (0 until numMetadataChunks).foreach(_ => buf.append(inputStream.readUTF())) + val colFamilyMetadataStr = buf.toString() + ColumnFamilyMetadataV1.fromString(colFamilyMetadataStr) + } + } + trait SchemaWriter { val version: Int @@ -144,4 +177,35 @@ object SchemaHelper { } } } + + /** + * Schema writer for schema version 3. Because this writer writes out ColFamilyMetadatas + * instead of key and value schemas, it is not compatible with the SchemaWriter interface. + */ + class SchemaV3Writer { + val version: Int = 3 + + // 2^16 - 1 bytes + final val MAX_UTF_CHUNK_SIZE = 65535 + + def writeSchema( + metadatas: Array[ColumnFamilyMetadataV1], + outputStream: FSDataOutputStream): Unit = { + val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + + // DataOutputStream.writeUTF can't write a string at once + // if the size exceeds 65535 (2^16 - 1) bytes. + // Each metadata consists of multiple chunks in schema version 3. + metadatas.foreach{ metadata => + val metadataJson = metadata.json + val numMetadataChunks = (metadataJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val metadataStringReader = new StringReader(metadataJson) + outputStream.writeInt(numMetadataChunks) + (0 until numMetadataChunks).foreach { _ => + val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) + outputStream.writeUTF(new String(buf, 0, numRead)) + } + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 8c2170abe3116..ac9f609976b62 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -28,6 +28,9 @@ import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path +import org.json4s.{JInt, JsonAST, JString} +import org.json4s.JsonAST.JObject +import org.json4s.JsonDSL._ import org.apache.spark.{SparkContext, SparkEnv, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, LogKeys, MDC} @@ -133,6 +136,10 @@ trait StateStore extends ReadStateStore { useMultipleValuesPerKey: Boolean = false, isInternal: Boolean = false): Unit + def createColFamilyIfAbsent( + colFamilyMetadata: ColumnFamilyMetadataV1 + ) + /** * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows * in the params can be reused, and must make copies of the data as needed for persistence. @@ -289,9 +296,16 @@ class InvalidUnsafeRowException(error: String) "among restart. 
For the first case, you can try to restart the application without " + s"checkpoint or use the legacy Spark version to process the streaming state.\n$error", null) -sealed trait KeyStateEncoderSpec +sealed trait KeyStateEncoderSpec { + def jsonValue: JsonAST.JObject +} -case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec +case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec { + override def jsonValue: JsonAST.JObject = { + ("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec")) ~ + ("keySchema" -> keySchema.jsonValue) + } +} case class PrefixKeyScanStateEncoderSpec( keySchema: StructType, @@ -299,6 +313,11 @@ case class PrefixKeyScanStateEncoderSpec( if (numColsPrefixKey == 0 || numColsPrefixKey >= keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForPrefixScan(numColsPrefixKey.toString) } + override def jsonValue: JsonAST.JObject = { + ("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~ + ("keySchema" -> keySchema.jsonValue) ~ + ("numColsPrefixKey" -> JInt(numColsPrefixKey)) + } } /** Encodes rows so that they can be range-scanned based on orderingOrdinals */ @@ -308,6 +327,12 @@ case class RangeKeyScanStateEncoderSpec( if (orderingOrdinals.isEmpty || orderingOrdinals.length > keySchema.length) { throw StateStoreErrors.incorrectNumOrderingColsForRangeScan(orderingOrdinals.length.toString) } + + override def jsonValue: JObject = { + ("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~ + ("keySchema" -> keySchema.jsonValue) ~ + ("orderingOrdinals" -> orderingOrdinals.map(JInt(_))) + } } /** From a621ba8819cc8990d64e63e84e1de45600b7da97 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 24 May 2024 09:13:02 -0700 Subject: [PATCH 03/11] schemav3 --- .../execution/streaming/MapStateImpl.scala | 1 - .../streaming/TransformWithStateExec.scala | 39 +++++++--- .../state/OperatorStateMetadata.scala | 10 +-- .../streaming/state/SchemaHelper.scala | 73 ++++++++++++++----- .../streaming/state/StateStore.scala | 2 +- .../streaming/state/MemoryStateStore.scala | 4 + .../streaming/TransformWithStateSuite.scala | 8 +- 7 files changed, 99 insertions(+), 38 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 99dfc763e9fd4..153400e51b5cd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -22,7 +22,6 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState -import org.apache.spark.sql.types.{BinaryType, StructType} class MapStateImpl[K, V]( store: StateStore, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 6e6f56e3169c8..1ffdb327c52a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -39,9 +39,10 @@ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.StateStoreAwareZipPartitionsHelper import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaV3Reader, SchemaV3Writer} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ -import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} +import org.apache.spark.util.{CollectionAccumulator, CompletionIterator, SerializableConfiguration, Utils} /** * Physical operator for executing `TransformWithState` @@ -86,10 +87,10 @@ case class TransformWithStateExec( override def shortName: String = "transformWithStateExec" - val operatorProperties: OperatorProperties = - OperatorProperties.create( + private val operatorPropertiesFromExecutor: OperatorPropertiesFromExecutor = + OperatorPropertiesFromExecutor.create( sparkContext, - "colFamilyMetadata" + "operatorPropsFromExecutor" ) /** Metadata of this stateful operator and its states stores. */ @@ -101,7 +102,7 @@ case class TransformWithStateExec( val operatorPropertiesJson: JValue = ("timeMode" -> JString(timeMode.toString)) ~ ("outputMode" -> JString(outputMode.toString)) ~ - ("stateVariables" -> operatorProperties.value.get("stateVariables")) + ("stateVariables" -> operatorPropertiesFromExecutor.value.get("stateVariables")) val json = compact(render(operatorPropertiesJson)) OperatorStateMetadataV2(operatorInfo, stateStoreInfo, json) @@ -114,13 +115,26 @@ case class TransformWithStateExec( * to write the metadata of the operator to the checkpoint file. */ override def writeOperatorStateMetadata(): Unit = { + val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation, + getStateInfo.operatorId.toString) + val metadata = operatorStateMetadata() + val hadoopConf = session.sqlContext.sessionState.newHadoopConf() val metadataWriter = new OperatorStateMetadataWriter( - new Path(stateInfo.get.checkpointLocation, - getStateInfo.operatorId.toString), - session.sqlContext.sessionState.newHadoopConf() + stateCheckpointPath, + hadoopConf ) metadataWriter.write(metadata) + + val schemaV3Writer = new SchemaV3Writer( + stateCheckpointPath, + hadoopConf + ) + + val jValue = operatorPropertiesFromExecutor.value.get("columnFamilyMetadatas") + val json = compact(render(jValue)) + schemaV3Writer.writeSchema(json) + super.writeOperatorStateMetadata() } @@ -353,7 +367,7 @@ case class TransformWithStateExec( store.abort() } } - operatorProperties.add(Map + operatorPropertiesFromExecutor.add(Map ("stateVariables" -> JArray(processorHandle.stateVariables. 
asScala.map(_.jsonValue).toList))) setStoreMetrics(store) @@ -429,6 +443,13 @@ case class TransformWithStateExec( } } else { if (isStreaming) { + val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation, + getStateInfo.operatorId.toString) + val hadoopConf = session.sqlContext.sessionState.newHadoopConf() + val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf) + reader.read.foreach { columnFamilyName => + print(s"### columnFamilyName: $columnFamilyName") + } child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, KEY_ROW_SCHEMA, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala index e210e669b6c10..05012b4724d46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/OperatorStateMetadata.scala @@ -79,7 +79,7 @@ case class OperatorStateMetadataV1( * available on the driver at the time of planning, and will only be known from * the executor side. */ -class OperatorProperties(initValue: Map[String, JValue] = Map.empty) +class OperatorPropertiesFromExecutor(initValue: Map[String, JValue] = Map.empty) extends AccumulatorV2[Map[String, JValue], Map[String, JValue]] { private var _value: Map[String, JValue] = initValue @@ -87,7 +87,7 @@ class OperatorProperties(initValue: Map[String, JValue] = Map.empty) override def isZero: Boolean = _value.isEmpty override def copy(): AccumulatorV2[Map[String, JValue], Map[String, JValue]] = { - val newAcc = new OperatorProperties + val newAcc = new OperatorPropertiesFromExecutor newAcc._value = _value newAcc } @@ -103,12 +103,12 @@ class OperatorProperties(initValue: Map[String, JValue] = Map.empty) override def value: Map[String, JValue] = _value } -object OperatorProperties { +object OperatorPropertiesFromExecutor { def create( sc: SparkContext, name: String, - initValue: Map[String, JValue] = Map.empty): OperatorProperties = { - val acc = new OperatorProperties(initValue) + initValue: Map[String, JValue] = Map.empty): OperatorPropertiesFromExecutor = { + val acc = new OperatorPropertiesFromExecutor(initValue) acc.register(sc, name = Some(name)) acc } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index ad30f186e636a..f45140eb63778 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql.execution.streaming.state import java.io.StringReader -import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream} -import org.json4s.JsonAST +import org.apache.hadoop.fs.{FSDataInputStream, FSDataOutputStream, Path} +import org.json4s.{DefaultFormats, JsonAST} import org.json4s.JsonAST._ import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.sql.execution.streaming.MetadataVersionUtil +import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -34,7 +35,7 @@ class ColumnFamilyMetadataV1( val 
keySchema: StructType, val valueSchema: StructType, val keyStateEncoderSpec: KeyStateEncoderSpec, - val multipleValuesPerKey: Boolean) { + val multipleValuesPerKey: Boolean) extends Serializable { def jsonValue: JsonAST.JObject = { ("columnFamilyName" -> JString(columnFamilyName)) ~ ("keySchema" -> keySchema.jsonValue) ~ @@ -91,13 +92,33 @@ object SchemaHelper { } } - class SchemaV3Reader { - def read(inputStream: FSDataInputStream): ColumnFamilyMetadataV1 = { + class SchemaV3Reader( + stateCheckpointPath: Path, + hadoopConf: org.apache.hadoop.conf.Configuration) { + + private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) + + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + def read: List[String] = { + if (!fm.exists(schemaFilePath)) { + return List.empty + } val buf = new StringBuilder - val numMetadataChunks = inputStream.readInt() - (0 until numMetadataChunks).foreach(_ => buf.append(inputStream.readUTF())) - val colFamilyMetadataStr = buf.toString() - ColumnFamilyMetadataV1.fromString(colFamilyMetadataStr) + val inputStream = fm.open(schemaFilePath) + val numKeyChunks = inputStream.readInt() + (0 until numKeyChunks).foreach(_ => buf.append(inputStream.readUTF())) + val json = buf.toString() + val parsedJson = JsonMethods.parse(json) + + implicit val formats = DefaultFormats + val deserializedList: List[Any] = parsedJson.extract[List[Any]] + assert(deserializedList.isInstanceOf[List[_]], + s"Expected List but got ${deserializedList.getClass}") + val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] + // Extract each JValue to StateVariableInfo + columnFamilyMetadatas.map { columnFamilyMetadata => + columnFamilyMetadata("columnFamilyName").asInstanceOf[String] + } } } @@ -178,33 +199,49 @@ object SchemaHelper { } } + object SchemaV3Writer { + def getSchemaFilePath(stateCheckpointPath: Path): Path = { + new Path(new Path(stateCheckpointPath, "_metadata"), "schema") + } + } /** * Schema writer for schema version 3. Because this writer writes out ColFamilyMetadatas * instead of key and value schemas, it is not compatible with the SchemaWriter interface. */ - class SchemaV3Writer { + class SchemaV3Writer( + stateCheckpointPath: Path, + hadoopConf: org.apache.hadoop.conf.Configuration) { val version: Int = 3 + private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) + private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) + // 2^16 - 1 bytes final val MAX_UTF_CHUNK_SIZE = 65535 - def writeSchema( - metadatas: Array[ColumnFamilyMetadataV1], - outputStream: FSDataOutputStream): Unit = { + def writeSchema(metadatasJson: String): Unit = { val buf = new Array[Char](MAX_UTF_CHUNK_SIZE) + if (fm.exists(schemaFilePath)) return + + fm.mkdirs(schemaFilePath.getParent) + val outputStream = fm.createAtomic(schemaFilePath, overwriteIfPossible = false) // DataOutputStream.writeUTF can't write a string at once // if the size exceeds 65535 (2^16 - 1) bytes. // Each metadata consists of multiple chunks in schema version 3. 
- metadatas.foreach{ metadata => - val metadataJson = metadata.json - val numMetadataChunks = (metadataJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 - val metadataStringReader = new StringReader(metadataJson) + try { + val numMetadataChunks = (metadatasJson.length - 1) / MAX_UTF_CHUNK_SIZE + 1 + val metadataStringReader = new StringReader(metadatasJson) outputStream.writeInt(numMetadataChunks) (0 until numMetadataChunks).foreach { _ => val numRead = metadataStringReader.read(buf, 0, MAX_UTF_CHUNK_SIZE) outputStream.writeUTF(new String(buf, 0, numRead)) } + outputStream.close() + } catch { + case e: Throwable => + outputStream.cancel() + throw e } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ac9f609976b62..45e8e6b099fd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -138,7 +138,7 @@ trait StateStore extends ReadStateStore { def createColFamilyIfAbsent( colFamilyMetadata: ColumnFamilyMetadataV1 - ) + ): Unit /** * Put a new non-null value for a non-null key. Implementations must be aware that the UnsafeRows diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index 6a476635a6dbe..aa81d7a1594d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -74,4 +74,8 @@ class MemoryStateStore extends StateStore() { override def valuesIterator(key: UnsafeRow, colFamilyName: String): Iterator[UnsafeRow] = { throw new UnsupportedOperationException("Doesn't support multiple values per key") } + + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = { + throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 862b02af354e6..d7cbcf505ca18 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -421,13 +421,13 @@ class TransformWithStateSuite extends StateStoreMetricsTest val propsString = df.select("operatorProperties"). collect().head.getString(0) - val map = TransformWithStateExec. + val operatorProperties = TransformWithStateExec. 
deserializeOperatorProperties(propsString) - assert(map("timeMode") === "ProcessingTime") - assert(map("outputMode") === "Update") + assert(operatorProperties("timeMode") === "ProcessingTime") + assert(operatorProperties("outputMode") === "Update") val stateVariableInfos = StateVariableInfo.fromJson( - map("stateVariables")) + operatorProperties("stateVariables")) assert(stateVariableInfos.size === 1) val stateVariableInfo = stateVariableInfos.head assert(stateVariableInfo.stateName === "countState") From 8feb83a0a6687730fbd7da0ee1e847de14529d38 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 24 May 2024 11:39:53 -0700 Subject: [PATCH 04/11] able to serialize, deserialize schemas --- .../streaming/TransformWithStateExec.scala | 5 +--- .../streaming/state/SchemaHelper.scala | 27 ++++++++++++++----- .../streaming/state/StateStore.scala | 26 +++++++++++++++--- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 1ffdb327c52a6..617dca0f42e1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.streaming.state._ import org.apache.spark.sql.execution.streaming.state.SchemaHelper.{SchemaV3Reader, SchemaV3Writer} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming._ -import org.apache.spark.util.{CollectionAccumulator, CompletionIterator, SerializableConfiguration, Utils} +import org.apache.spark.util.{CompletionIterator, SerializableConfiguration, Utils} /** * Physical operator for executing `TransformWithState` @@ -447,9 +447,6 @@ case class TransformWithStateExec( getStateInfo.operatorId.toString) val hadoopConf = session.sqlContext.sessionState.newHadoopConf() val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf) - reader.read.foreach { columnFamilyName => - print(s"### columnFamilyName: $columnFamilyName") - } child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, KEY_ROW_SCHEMA, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index f45140eb63778..9c7f4e3aae656 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -38,8 +38,8 @@ class ColumnFamilyMetadataV1( val multipleValuesPerKey: Boolean) extends Serializable { def jsonValue: JsonAST.JObject = { ("columnFamilyName" -> JString(columnFamilyName)) ~ - ("keySchema" -> keySchema.jsonValue) ~ - ("valueSchema" -> valueSchema.jsonValue) ~ + ("keySchema" -> keySchema.json) ~ + ("valueSchema" -> valueSchema.json) ~ ("keyStateEncoderSpec" -> keyStateEncoderSpec.jsonValue) ~ ("multipleValuesPerKey" -> JBool(multipleValuesPerKey)) } @@ -49,6 +49,23 @@ class ColumnFamilyMetadataV1( } } +object ColumnFamilyMetadataV1 { + def fromJson(json: List[Map[String, Any]]): List[ColumnFamilyMetadataV1] = { + assert(json.isInstanceOf[List[_]]) + + json.map { colFamilyMap => + new ColumnFamilyMetadataV1( + colFamilyMap("columnFamilyName").asInstanceOf[String], + 
StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]), + StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]), + KeyStateEncoderSpec.fromJson(colFamilyMap("keyStateEncoderSpec") + .asInstanceOf[Map[String, Any]]), + colFamilyMap("multipleValuesPerKey").asInstanceOf[Boolean] + ) + } + } +} + /** * Helper classes for reading/writing state schema. */ @@ -99,7 +116,7 @@ object SchemaHelper { private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) - def read: List[String] = { + def read: List[ColumnFamilyMetadataV1] = { if (!fm.exists(schemaFilePath)) { return List.empty } @@ -116,9 +133,7 @@ object SchemaHelper { s"Expected List but got ${deserializedList.getClass}") val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] // Extract each JValue to StateVariableInfo - columnFamilyMetadatas.map { columnFamilyMetadata => - columnFamilyMetadata("columnFamilyName").asInstanceOf[String] - } + ColumnFamilyMetadataV1.fromJson(columnFamilyMetadatas) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index 45e8e6b099fd2..f50d1019887eb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -31,6 +31,7 @@ import org.apache.hadoop.fs.Path import org.json4s.{JInt, JsonAST, JString} import org.json4s.JsonAST.JObject import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} import org.apache.spark.{SparkContext, SparkEnv, SparkUnsupportedOperationException} import org.apache.spark.internal.{Logging, LogKeys, MDC} @@ -298,12 +299,31 @@ class InvalidUnsafeRowException(error: String) sealed trait KeyStateEncoderSpec { def jsonValue: JsonAST.JObject + def json: String = compact(render(jsonValue)) +} + +object KeyStateEncoderSpec { + def fromJson(m: Map[String, Any]): KeyStateEncoderSpec = { + // match on type + val keySchema = StructType.fromString(m("keySchema").asInstanceOf[String]) + m("keyStateEncoderType").asInstanceOf[String] match { + case "NoPrefixKeyStateEncoderSpec" => + NoPrefixKeyStateEncoderSpec(keySchema) + case "RangeKeyScanStateEncoderSpec" => + val orderingOrdinals = m("orderingOrdinals"). 
+ asInstanceOf[List[_]].map(_.asInstanceOf[Int]) + RangeKeyScanStateEncoderSpec(keySchema, orderingOrdinals) + case "PrefixKeyScanStateEncoderSpec" => + val numColsPrefixKey = m("numColsPrefixKey").asInstanceOf[Int] + PrefixKeyScanStateEncoderSpec(keySchema, numColsPrefixKey) + } + } } case class NoPrefixKeyStateEncoderSpec(keySchema: StructType) extends KeyStateEncoderSpec { override def jsonValue: JsonAST.JObject = { ("keyStateEncoderType" -> JString("NoPrefixKeyStateEncoderSpec")) ~ - ("keySchema" -> keySchema.jsonValue) + ("keySchema" -> JString(keySchema.json)) } } @@ -315,7 +335,7 @@ case class PrefixKeyScanStateEncoderSpec( } override def jsonValue: JsonAST.JObject = { ("keyStateEncoderType" -> JString("PrefixKeyScanStateEncoderSpec")) ~ - ("keySchema" -> keySchema.jsonValue) ~ + ("keySchema" -> JString(keySchema.json)) ~ ("numColsPrefixKey" -> JInt(numColsPrefixKey)) } } @@ -330,7 +350,7 @@ case class RangeKeyScanStateEncoderSpec( override def jsonValue: JObject = { ("keyStateEncoderType" -> JString("RangeKeyScanStateEncoderSpec")) ~ - ("keySchema" -> keySchema.jsonValue) ~ + ("keySchema" -> JString(keySchema.json)) ~ ("orderingOrdinals" -> orderingOrdinals.map(JInt(_))) } } From 6a68d6b1f0318116825339490e4830247375143e Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 24 May 2024 13:35:00 -0700 Subject: [PATCH 05/11] adding colfamilymetadata trait --- .../sql/execution/streaming/ListStateImpl.scala | 4 ++-- .../streaming/ListStateImplWithTTL.scala | 4 ++-- .../sql/execution/streaming/MapStateImpl.scala | 4 ++-- .../streaming/MapStateImplWithTTL.scala | 4 ++-- .../streaming/StatefulProcessorHandleImpl.scala | 16 ++++++++-------- .../sql/execution/streaming/ValueStateImpl.scala | 4 ++-- .../streaming/ValueStateImplWithTTL.scala | 4 ++-- .../execution/streaming/state/SchemaHelper.scala | 14 ++++++++++---- 8 files changed, 30 insertions(+), 24 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index fb8a18db860cd..97a9b9f99b9d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -44,9 +44,9 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilyMetadataV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) - store.createColFamilyIfAbsent(columnFamilyMetadataV1) + store.createColFamilyIfAbsent(columnFamilyMetadata) /** Whether state exists or not. 
*/ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 350b70b28f7aa..0a38a4287bd31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -52,13 +52,13 @@ class ListStateImplWithTTL[S]( private lazy val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilyMetadataV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadataV1) + store.createColFamilyIfAbsent(columnFamilyMetadata) } /** Whether state exists or not. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 153400e51b5cd..49b76d4aa50a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -34,11 +34,11 @@ class MapStateImpl[K, V]( private val stateTypesEncoder = new CompositeKeyStateEncoder( keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilyMetadataV1( stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) - store.createColFamilyIfAbsent(columnFamilyMetadataV1) + store.createColFamilyIfAbsent(columnFamilyMetadata) /** Whether state exists or not. */ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index fe82e37035743..54af06641edfd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -55,13 +55,13 @@ class MapStateImplWithTTL[K, V]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilyMetadataV1( stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadataV1) + store.createColFamilyIfAbsent(columnFamilyMetadata) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 1b32aa066d69b..27d74ecb2c9d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -97,8 +97,8 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() - var columnFamilyMetadatas: mutable.ListBuffer[ColumnFamilyMetadataV1] = - mutable.ListBuffer.empty[ColumnFamilyMetadataV1] + var columnFamilyMetadatas: mutable.ListBuffer[ColumnFamilyMetadata] = + mutable.ListBuffer.empty[ColumnFamilyMetadata] private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" @@ -138,7 +138,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) incrementMetric("numValueStateVars") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) - columnFamilyMetadatas.addOne(resultState.columnFamilyMetadataV1) + columnFamilyMetadatas.addOne(resultState.columnFamilyMetadata) resultState } @@ -155,7 +155,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numValueStateWithTTLVars") ttlStates.add(valueStateWithTTL) - columnFamilyMetadatas.addOne(valueStateWithTTL.columnFamilyMetadataV1) + columnFamilyMetadatas.addOne(valueStateWithTTL.columnFamilyMetadata) valueStateWithTTL } @@ -253,7 +253,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, ListState, false)) incrementMetric("numListStateVars") val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) - columnFamilyMetadatas.addOne(resultState.columnFamilyMetadataV1) + columnFamilyMetadatas.addOne(resultState.columnFamilyMetadata) resultState } @@ -286,7 +286,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numListStateWithTTLVars") ttlStates.add(listStateWithTTL) - columnFamilyMetadatas.addOne(listStateWithTTL.columnFamilyMetadataV1) + columnFamilyMetadatas.addOne(listStateWithTTL.columnFamilyMetadata) listStateWithTTL } @@ -299,7 +299,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, MapState, false)) incrementMetric("numMapStateVars") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) - columnFamilyMetadatas.addOne(resultState.columnFamilyMetadataV1) + columnFamilyMetadatas.addOne(resultState.columnFamilyMetadata) resultState } @@ -317,7 +317,7 @@ class StatefulProcessorHandleImpl( valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) - columnFamilyMetadatas.addOne(mapStateWithTTL.columnFamilyMetadataV1) + columnFamilyMetadatas.addOne(mapStateWithTTL.columnFamilyMetadata) mapStateWithTTL } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 7a02a340b3cd9..ff7c20fee3837 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -42,12 +42,12 @@ class ValueStateImpl[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilyMetadataV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadataV1) + store.createColFamilyIfAbsent(columnFamilyMetadata) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index cc33e7fae694c..abb6cce31e727 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -49,13 +49,13 @@ class ValueStateImplWithTTL[S]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadataV1 = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilyMetadataV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadataV1) + store.createColFamilyIfAbsent(columnFamilyMetadata) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 9c7f4e3aae656..be8d8ea8b862b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -30,12 +30,18 @@ import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, Metadata import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils -class ColumnFamilyMetadataV1( +sealed trait ColumnFamilyMetadata extends Serializable { + def jsonValue: JsonAST.JObject + + def json: String +} + +case class ColumnFamilyMetadataV1( val columnFamilyName: String, val keySchema: StructType, val valueSchema: StructType, val keyStateEncoderSpec: KeyStateEncoderSpec, - val multipleValuesPerKey: Boolean) extends Serializable { + val multipleValuesPerKey: Boolean) extends ColumnFamilyMetadata { def jsonValue: JsonAST.JObject = { ("columnFamilyName" -> JString(columnFamilyName)) ~ ("keySchema" -> keySchema.json) ~ @@ -50,7 +56,7 @@ class ColumnFamilyMetadataV1( } object ColumnFamilyMetadataV1 { - def fromJson(json: List[Map[String, Any]]): List[ColumnFamilyMetadataV1] = { + def fromJson(json: List[Map[String, Any]]): List[ColumnFamilyMetadata] = { assert(json.isInstanceOf[List[_]]) json.map { colFamilyMap => @@ -116,7 +122,7 @@ object SchemaHelper { private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) - def read: List[ColumnFamilyMetadataV1] = { + def read: 
List[ColumnFamilyMetadata] = { if (!fm.exists(schemaFilePath)) { return List.empty } From 8f759431ab3b1bc5d8852ecd814655de0c2bd43c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 24 May 2024 13:45:24 -0700 Subject: [PATCH 06/11] using arraylist --- .../streaming/StatefulProcessorHandleImpl.scala | 16 ++++++++-------- .../streaming/TransformWithStateExec.scala | 16 +++++++++++++--- 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 27d74ecb2c9d8..9b169fc4b3344 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -97,8 +97,8 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() - var columnFamilyMetadatas: mutable.ListBuffer[ColumnFamilyMetadata] = - mutable.ListBuffer.empty[ColumnFamilyMetadata] + private[sql] val columnFamilyMetadatas: util.List[ColumnFamilyMetadata] = + new util.ArrayList[ColumnFamilyMetadata]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" @@ -138,7 +138,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) incrementMetric("numValueStateVars") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) - columnFamilyMetadatas.addOne(resultState.columnFamilyMetadata) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) resultState } @@ -155,7 +155,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numValueStateWithTTLVars") ttlStates.add(valueStateWithTTL) - columnFamilyMetadatas.addOne(valueStateWithTTL.columnFamilyMetadata) + columnFamilyMetadatas.add(valueStateWithTTL.columnFamilyMetadata) valueStateWithTTL } @@ -253,7 +253,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, ListState, false)) incrementMetric("numListStateVars") val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) - columnFamilyMetadatas.addOne(resultState.columnFamilyMetadata) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) resultState } @@ -286,7 +286,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numListStateWithTTLVars") ttlStates.add(listStateWithTTL) - columnFamilyMetadatas.addOne(listStateWithTTL.columnFamilyMetadata) + columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) listStateWithTTL } @@ -299,7 +299,7 @@ class StatefulProcessorHandleImpl( stateVariables.add(new StateVariableInfo(stateName, MapState, false)) incrementMetric("numMapStateVars") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) - columnFamilyMetadatas.addOne(resultState.columnFamilyMetadata) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) resultState } @@ -317,7 +317,7 @@ class StatefulProcessorHandleImpl( valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) - columnFamilyMetadatas.addOne(mapStateWithTTL.columnFamilyMetadata) + columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) 
mapStateWithTTL } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 617dca0f42e1b..069ca22b92a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -367,9 +367,19 @@ case class TransformWithStateExec( store.abort() } } - operatorPropertiesFromExecutor.add(Map - ("stateVariables" -> JArray(processorHandle.stateVariables. - asScala.map(_.jsonValue).toList))) + + // only write this information out for partition 0 + if (store.id.partitionId == 0) { + operatorPropertiesFromExecutor.add(Map + ("stateVariables" -> JArray(processorHandle.stateVariables. + asScala.map(_.jsonValue).toList))) + + operatorPropertiesFromExecutor.add(Map + ("columnFamilyMetadatas" -> + JArray(processorHandle.columnFamilyMetadatas. + asScala.map(_.jsonValue).toList))) + } + setStoreMetrics(store) setOperatorMetrics() statefulProcessor.close() From 970cc135149f32ce20569166aed21e3ae5da1b63 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Fri, 24 May 2024 16:14:16 -0700 Subject: [PATCH 07/11] moving stateVariables.add --- .../streaming/StatefulProcessorHandleImpl.scala | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 9b169fc4b3344..5f70edf310342 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -135,9 +135,9 @@ class StatefulProcessorHandleImpl( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { verifyStateVarOperations("get_value_state") - stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) incrementMetric("numValueStateVars") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) columnFamilyMetadatas.add(resultState.columnFamilyMetadata) resultState } @@ -147,7 +147,6 @@ class StatefulProcessorHandleImpl( valEncoder: Encoder[T], ttlConfig: TTLConfig): ValueState[T] = { verifyStateVarOperations("get_value_state") - stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) @@ -155,6 +154,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numValueStateWithTTLVars") ttlStates.add(valueStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) columnFamilyMetadatas.add(valueStateWithTTL.columnFamilyMetadata) valueStateWithTTL } @@ -250,9 +250,9 @@ class StatefulProcessorHandleImpl( override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { verifyStateVarOperations("get_list_state") - stateVariables.add(new StateVariableInfo(stateName, ListState, false)) incrementMetric("numListStateVars") val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, ListState, false)) 
columnFamilyMetadatas.add(resultState.columnFamilyMetadata) resultState } @@ -278,7 +278,6 @@ class StatefulProcessorHandleImpl( ttlConfig: TTLConfig): ListState[T] = { verifyStateVarOperations("get_list_state") - stateVariables.add(new StateVariableInfo(stateName, ListState, true)) validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) @@ -286,6 +285,7 @@ class StatefulProcessorHandleImpl( keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numListStateWithTTLVars") ttlStates.add(listStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, ListState, true)) columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) listStateWithTTL @@ -296,9 +296,9 @@ class StatefulProcessorHandleImpl( userKeyEnc: Encoder[K], valEncoder: Encoder[V]): MapState[K, V] = { verifyStateVarOperations("get_map_state") - stateVariables.add(new StateVariableInfo(stateName, MapState, false)) incrementMetric("numMapStateVars") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, MapState, false)) columnFamilyMetadatas.add(resultState.columnFamilyMetadata) resultState } @@ -309,7 +309,6 @@ class StatefulProcessorHandleImpl( valEncoder: Encoder[V], ttlConfig: TTLConfig): MapState[K, V] = { verifyStateVarOperations("get_map_state") - stateVariables.add(new StateVariableInfo(stateName, MapState, true)) validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) @@ -317,6 +316,7 @@ class StatefulProcessorHandleImpl( valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numMapStateWithTTLVars") ttlStates.add(mapStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, MapState, true)) columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) mapStateWithTTL From abf8c04b2e57fd85e5f1f16d4c2cb187617225e8 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 27 May 2024 18:32:16 -0700 Subject: [PATCH 08/11] renaming to ColumnFamilySchema --- .../execution/streaming/ListStateImpl.scala | 4 +- .../streaming/ListStateImplWithTTL.scala | 4 +- .../execution/streaming/MapStateImpl.scala | 4 +- .../streaming/MapStateImplWithTTL.scala | 4 +- .../StatefulProcessorHandleImpl.scala | 8 +- .../streaming/TransformWithStateExec.scala | 24 ++++-- .../execution/streaming/ValueStateImpl.scala | 4 +- .../streaming/ValueStateImplWithTTL.scala | 4 +- .../state/HDFSBackedStateStoreProvider.scala | 2 +- .../state/RocksDBStateStoreProvider.scala | 2 +- .../streaming/state/SchemaHelper.scala | 76 ++++++++++++++++--- .../streaming/state/StateStore.scala | 2 +- .../streaming/state/MemoryStateStore.scala | 2 +- 13 files changed, 107 insertions(+), 33 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 97a9b9f99b9d1..9cfaa4a454b92 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore, 
StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.ListState /** @@ -44,7 +44,7 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - val columnFamilyMetadata = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) store.createColFamilyIfAbsent(columnFamilyMetadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 0a38a4287bd31..8e69ea607546a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{ListState, TTLConfig} import org.apache.spark.util.NextIterator @@ -52,7 +52,7 @@ class ListStateImplWithTTL[S]( private lazy val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadata = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true) initialize() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 49b76d4aa50a6..2587f111a2077 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair} import org.apache.spark.sql.streaming.MapState class MapStateImpl[K, V]( @@ -34,7 +34,7 @@ class MapStateImpl[K, V]( private val stateTypesEncoder = new CompositeKeyStateEncoder( keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - val columnFamilyMetadata = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilySchemaV1( stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, 
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index 54af06641edfd..510c6870fc822 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors} import org.apache.spark.sql.streaming.{MapState, TTLConfig} import org.apache.spark.util.NextIterator @@ -55,7 +55,7 @@ class MapStateImplWithTTL[K, V]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadata = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilySchemaV1( stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) initialize() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 5f70edf310342..453e4bd25f0a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -84,7 +84,8 @@ class StatefulProcessorHandleImpl( timeMode: TimeMode, isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, - metrics: Map[String, SQLMetric] = Map.empty) + metrics: Map[String, SQLMetric] = Map.empty, + existingColFamilies: Map[String, ColumnFamilyAccumulator] = Map.empty) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ @@ -97,8 +98,8 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() - private[sql] val columnFamilyMetadatas: util.List[ColumnFamilyMetadata] = - new util.ArrayList[ColumnFamilyMetadata]() + private[sql] val columnFamilyMetadatas: util.List[ColumnFamilySchema] = + new util.ArrayList[ColumnFamilySchema]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" @@ -168,6 +169,7 @@ class StatefulProcessorHandleImpl( throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType, currState.toString) } + } private def verifyTimerOperations(operationType: String): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 069ca22b92a3f..6a6b4b1c9cf3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -93,6 +93,22 @@ case class TransformWithStateExec( "operatorPropsFromExecutor" ) + private lazy val colFamilyAccumulators: Map[String, ColumnFamilyAccumulator] = + initializeColFamilyAccumulators() + + private def initializeColFamilyAccumulators(): Map[String, ColumnFamilyAccumulator] = { + val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation, + getStateInfo.operatorId.toString) + val hadoopConf = session.sqlContext.sessionState.newHadoopConf() + + val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf) + + reader.read.map { colFamilyMetadata => + val acc = ColumnFamilyAccumulator.create(colFamilyMetadata, sparkContext) + colFamilyMetadata.asInstanceOf[ColumnFamilySchemaV1].columnFamilyName -> acc + }.toMap + } + /** Metadata of this stateful operator and its states stores. */ override def operatorStateMetadata(): OperatorStateMetadata = { val info = getStateInfo @@ -414,6 +430,7 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver + colFamilyAccumulators validateTimeMode() @@ -453,10 +470,7 @@ case class TransformWithStateExec( } } else { if (isStreaming) { - val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation, - getStateInfo.operatorId.toString) - val hadoopConf = session.sqlContext.sessionState.newHadoopConf() - val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf) + child.execute().mapPartitionsWithStateStore[InternalRow]( getStateInfo, KEY_ROW_SCHEMA, @@ -535,7 +549,7 @@ case class TransformWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, timeMode, - isStreaming, batchTimestampMs, metrics) + isStreaming, batchTimestampMs, metrics, colFamilyAccumulators) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index ff7c20fee3837..3922f41b92d60 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA} -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.ValueState /** @@ -42,7 +42,7 @@ class ValueStateImpl[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - val columnFamilyMetadata = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index abb6cce31e727..a26071685b8bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL} -import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore} +import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore} import org.apache.spark.sql.streaming.{TTLConfig, ValueState} /** @@ -49,7 +49,7 @@ class ValueStateImplWithTTL[S]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadata = new ColumnFamilyMetadataV1( + val columnFamilyMetadata = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 9a298ae6d5aab..2887ea6825d5c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -246,7 +246,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with throw StateStoreErrors.unsupportedOperationException("merge", providerName) } - override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = { + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala index 82bf20365e52c..05ac246bc7552 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/RocksDBStateStoreProvider.scala @@ -265,7 +265,7 @@ private[sql] class RocksDBStateStoreProvider result } - override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = { + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { createColFamilyIfAbsent( colFamilyMetadata.columnFamilyName, colFamilyMetadata.keySchema, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index be8d8ea8b862b..340d5c23e6cce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -26,22 +26,23 @@ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.{compact, render} +import org.apache.spark.SparkContext import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils +import org.apache.spark.util.{AccumulatorV2, Utils} -sealed trait ColumnFamilyMetadata extends Serializable { +sealed trait ColumnFamilySchema extends Serializable { def jsonValue: JsonAST.JObject def json: String } -case class ColumnFamilyMetadataV1( +case class ColumnFamilySchemaV1( val columnFamilyName: String, val keySchema: StructType, val valueSchema: StructType, val keyStateEncoderSpec: KeyStateEncoderSpec, - val multipleValuesPerKey: Boolean) extends ColumnFamilyMetadata { + val multipleValuesPerKey: Boolean) extends ColumnFamilySchema { def jsonValue: JsonAST.JObject = { ("columnFamilyName" -> JString(columnFamilyName)) ~ ("keySchema" -> keySchema.json) ~ @@ -55,12 +56,69 @@ case class ColumnFamilyMetadataV1( } } -object ColumnFamilyMetadataV1 { - def fromJson(json: List[Map[String, Any]]): List[ColumnFamilyMetadata] = { +class ColumnFamilyAccumulator( + columnFamilyMetadata: ColumnFamilySchema) extends + AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema] { + + private var _value: ColumnFamilySchema = columnFamilyMetadata + /** + * Returns if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero + * value; for a list accumulator, Nil is zero value. + */ + override def isZero: Boolean = _value == null + + /** + * Creates a new copy of this accumulator. + */ + override def copy(): AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema] = { + new ColumnFamilyAccumulator(_value) + } + + /** + * Resets this accumulator, which is zero value. i.e. call `isZero` must + * return true. + */ + override def reset(): Unit = { + _value = null + } + + /** + * Takes the inputs and accumulates. + */ + override def add(v: ColumnFamilySchema): Unit = { + _value = v + } + + /** + * Merges another same-type accumulator into this one and update its state, i.e. this should be + * merge-in-place. 
+ */ + override def merge(other: AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema]): Unit = { + _value = other.value + } + + /** + * Defines the current value of this accumulator + */ + override def value: ColumnFamilySchema = _value +} + +object ColumnFamilyAccumulator { + def create( + columnFamilyMetadata: ColumnFamilySchema, + sparkContext: SparkContext): ColumnFamilyAccumulator = { + val acc = new ColumnFamilyAccumulator(columnFamilyMetadata) + acc.register(sparkContext) + acc + } +} + +object ColumnFamilySchemaV1 { + def fromJson(json: List[Map[String, Any]]): List[ColumnFamilySchema] = { assert(json.isInstanceOf[List[_]]) json.map { colFamilyMap => - new ColumnFamilyMetadataV1( + new ColumnFamilySchemaV1( colFamilyMap("columnFamilyName").asInstanceOf[String], StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]), StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]), @@ -122,7 +180,7 @@ object SchemaHelper { private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath) private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf) - def read: List[ColumnFamilyMetadata] = { + def read: List[ColumnFamilySchema] = { if (!fm.exists(schemaFilePath)) { return List.empty } @@ -139,7 +197,7 @@ object SchemaHelper { s"Expected List but got ${deserializedList.getClass}") val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]] // Extract each JValue to StateVariableInfo - ColumnFamilyMetadataV1.fromJson(columnFamilyMetadatas) + ColumnFamilySchemaV1.fromJson(columnFamilyMetadatas) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index f50d1019887eb..b90315ab15ad6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -138,7 +138,7 @@ trait StateStore extends ReadStateStore { isInternal: Boolean = false): Unit def createColFamilyIfAbsent( - colFamilyMetadata: ColumnFamilyMetadataV1 + colFamilyMetadata: ColumnFamilySchemaV1 ): Unit /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala index aa81d7a1594d2..12795004d4859 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/MemoryStateStore.scala @@ -75,7 +75,7 @@ class MemoryStateStore extends StateStore() { throw new UnsupportedOperationException("Doesn't support multiple values per key") } - override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = { + override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = { throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider") } } From 37526de0f014a4ff25af37086f783a18439a25c0 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 28 May 2024 09:03:20 -0700 Subject: [PATCH 09/11] moving getListState/getMapState methods --- .../StatefulProcessorHandleImpl.scala | 148 +++++++++--------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 453e4bd25f0a9..218b9a5effaa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -160,6 +160,80 @@ class StatefulProcessorHandleImpl( valueStateWithTTL } + override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { + verifyStateVarOperations("get_list_state") + incrementMetric("numListStateVars") + val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, ListState, false)) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + resultState + } + + /** + * Function to create new or return existing list state variable of given type + * with ttl. State values will not be returned past ttlDuration, and will be eventually removed + * from the state store. Any values in listState which have expired after ttlDuration will not + * returned on get() and will be eventually removed from the state. + * + * The user must ensure to call this function only within the `init()` method of the + * StatefulProcessor. + * + * @param stateName - name of the state variable + * @param valEncoder - SQL encoder for state variable + * @param ttlConfig - the ttl configuration (time to live duration etc.) + * @tparam T - type of state variable + * @return - instance of ListState of type T that can be used to store state persistently + */ + override def getListState[T]( + stateName: String, + valEncoder: Encoder[T], + ttlConfig: TTLConfig): ListState[T] = { + + verifyStateVarOperations("get_list_state") + validateTTLConfig(ttlConfig, stateName) + + assert(batchTimestampMs.isDefined) + val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) + incrementMetric("numListStateWithTTLVars") + ttlStates.add(listStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, ListState, true)) + columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) + + listStateWithTTL + } + + override def getMapState[K, V]( + stateName: String, + userKeyEnc: Encoder[K], + valEncoder: Encoder[V]): MapState[K, V] = { + verifyStateVarOperations("get_map_state") + incrementMetric("numMapStateVars") + val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, MapState, false)) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + resultState + } + + override def getMapState[K, V]( + stateName: String, + userKeyEnc: Encoder[K], + valEncoder: Encoder[V], + ttlConfig: TTLConfig): MapState[K, V] = { + verifyStateVarOperations("get_map_state") + validateTTLConfig(ttlConfig, stateName) + + assert(batchTimestampMs.isDefined) + val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, + valEncoder, ttlConfig, batchTimestampMs.get) + incrementMetric("numMapStateWithTTLVars") + ttlStates.add(mapStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, MapState, true)) + columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) + + mapStateWithTTL + } + override def getQueryInfo(): QueryInfo = currQueryInfo private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder) @@ -250,80 +324,6 @@ class StatefulProcessorHandleImpl( } 
} - override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { - verifyStateVarOperations("get_list_state") - incrementMetric("numListStateVars") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) - stateVariables.add(new StateVariableInfo(stateName, ListState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) - resultState - } - - /** - * Function to create new or return existing list state variable of given type - * with ttl. State values will not be returned past ttlDuration, and will be eventually removed - * from the state store. Any values in listState which have expired after ttlDuration will not - * returned on get() and will be eventually removed from the state. - * - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently - */ - override def getListState[T]( - stateName: String, - valEncoder: Encoder[T], - ttlConfig: TTLConfig): ListState[T] = { - - verifyStateVarOperations("get_list_state") - validateTTLConfig(ttlConfig, stateName) - - assert(batchTimestampMs.isDefined) - val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numListStateWithTTLVars") - ttlStates.add(listStateWithTTL) - stateVariables.add(new StateVariableInfo(stateName, ListState, true)) - columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) - - listStateWithTTL - } - - override def getMapState[K, V]( - stateName: String, - userKeyEnc: Encoder[K], - valEncoder: Encoder[V]): MapState[K, V] = { - verifyStateVarOperations("get_map_state") - incrementMetric("numMapStateVars") - val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) - stateVariables.add(new StateVariableInfo(stateName, MapState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) - resultState - } - - override def getMapState[K, V]( - stateName: String, - userKeyEnc: Encoder[K], - valEncoder: Encoder[V], - ttlConfig: TTLConfig): MapState[K, V] = { - verifyStateVarOperations("get_map_state") - validateTTLConfig(ttlConfig, stateName) - - assert(batchTimestampMs.isDefined) - val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numMapStateWithTTLVars") - ttlStates.add(mapStateWithTTL) - stateVariables.add(new StateVariableInfo(stateName, MapState, true)) - columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) - - mapStateWithTTL - } - private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit = { val ttlDuration = ttlConfig.ttlDuration if (timeMode != TimeMode.ProcessingTime()) { From 67acea80451a67f340b30e48a03fd8b9920b8d3c Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Mon, 3 Jun 2024 15:54:44 -0700 Subject: [PATCH 10/11] removing the columnFamilyAccumulator --- .../StatefulProcessorHandleImpl.scala | 2 +- .../streaming/TransformWithStateExec.scala | 12 ++-- .../streaming/state/SchemaHelper.scala | 60 +------------------ 3 files changed, 8 insertions(+), 66 deletions(-) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 218b9a5effaa4..0fad7f92c63d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -85,7 +85,7 @@ class StatefulProcessorHandleImpl( isStreaming: Boolean = true, batchTimestampMs: Option[Long] = None, metrics: Map[String, SQLMetric] = Map.empty, - existingColFamilies: Map[String, ColumnFamilyAccumulator] = Map.empty) + existingColFamilies: Map[String, ColumnFamilySchemaV1] = Map.empty) extends StatefulProcessorHandle with Logging { import StatefulProcessorHandleState._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index 6a6b4b1c9cf3f..efba8dd10b843 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -93,10 +93,10 @@ case class TransformWithStateExec( "operatorPropsFromExecutor" ) - private lazy val colFamilyAccumulators: Map[String, ColumnFamilyAccumulator] = + private lazy val colFamilySchemas: Map[String, ColumnFamilySchemaV1] = initializeColFamilyAccumulators() - private def initializeColFamilyAccumulators(): Map[String, ColumnFamilyAccumulator] = { + private def initializeColFamilyAccumulators(): Map[String, ColumnFamilySchemaV1] = { val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation, getStateInfo.operatorId.toString) val hadoopConf = session.sqlContext.sessionState.newHadoopConf() @@ -104,8 +104,8 @@ case class TransformWithStateExec( val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf) reader.read.map { colFamilyMetadata => - val acc = ColumnFamilyAccumulator.create(colFamilyMetadata, sparkContext) - colFamilyMetadata.asInstanceOf[ColumnFamilySchemaV1].columnFamilyName -> acc + val schemaV1 = colFamilyMetadata.asInstanceOf[ColumnFamilySchemaV1] + schemaV1.columnFamilyName -> schemaV1 }.toMap } @@ -430,7 +430,7 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver - colFamilyAccumulators + colFamilySchemas validateTimeMode() @@ -549,7 +549,7 @@ case class TransformWithStateExec( CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, timeMode, - isStreaming, batchTimestampMs, metrics, colFamilyAccumulators) + isStreaming, batchTimestampMs, metrics, colFamilySchemas) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala index 340d5c23e6cce..bc4749c0c5253 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/SchemaHelper.scala @@ -26,10 +26,9 @@ import org.json4s.JsonDSL._ 
import org.json4s.jackson.JsonMethods import org.json4s.jackson.JsonMethods.{compact, render} -import org.apache.spark.SparkContext import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil} import org.apache.spark.sql.types.StructType -import org.apache.spark.util.{AccumulatorV2, Utils} +import org.apache.spark.util.Utils sealed trait ColumnFamilySchema extends Serializable { def jsonValue: JsonAST.JObject @@ -56,63 +55,6 @@ case class ColumnFamilySchemaV1( } } -class ColumnFamilyAccumulator( - columnFamilyMetadata: ColumnFamilySchema) extends - AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema] { - - private var _value: ColumnFamilySchema = columnFamilyMetadata - /** - * Returns if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero - * value; for a list accumulator, Nil is zero value. - */ - override def isZero: Boolean = _value == null - - /** - * Creates a new copy of this accumulator. - */ - override def copy(): AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema] = { - new ColumnFamilyAccumulator(_value) - } - - /** - * Resets this accumulator, which is zero value. i.e. call `isZero` must - * return true. - */ - override def reset(): Unit = { - _value = null - } - - /** - * Takes the inputs and accumulates. - */ - override def add(v: ColumnFamilySchema): Unit = { - _value = v - } - - /** - * Merges another same-type accumulator into this one and update its state, i.e. this should be - * merge-in-place. - */ - override def merge(other: AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema]): Unit = { - _value = other.value - } - - /** - * Defines the current value of this accumulator - */ - override def value: ColumnFamilySchema = _value -} - -object ColumnFamilyAccumulator { - def create( - columnFamilyMetadata: ColumnFamilySchema, - sparkContext: SparkContext): ColumnFamilyAccumulator = { - val acc = new ColumnFamilyAccumulator(columnFamilyMetadata) - acc.register(sparkContext) - acc - } -} - object ColumnFamilySchemaV1 { def fromJson(json: List[Map[String, Any]]): List[ColumnFamilySchema] = { assert(json.isInstanceOf[List[_]]) From 411c5b893778584a3d4d87eb095786afefa1cf79 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 4 Jun 2024 12:05:37 -0700 Subject: [PATCH 11/11] propagating information to executor through broadcast --- .../execution/streaming/ListStateImpl.scala | 4 +- .../streaming/ListStateImplWithTTL.scala | 4 +- .../execution/streaming/MapStateImpl.scala | 4 +- .../streaming/MapStateImplWithTTL.scala | 4 +- .../StatefulProcessorHandleImpl.scala | 57 +++++++++---- .../streaming/TransformWithStateExec.scala | 21 +++-- .../execution/streaming/ValueStateImpl.scala | 4 +- .../streaming/ValueStateImplWithTTL.scala | 4 +- .../streaming/TransformWithStateSuite.scala | 83 +++++++++++++++++++ 9 files changed, 150 insertions(+), 35 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala index 9cfaa4a454b92..b74afd6f418db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImpl.scala @@ -44,9 +44,9 @@ class ListStateImpl[S]( private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - val columnFamilyMetadata = new ColumnFamilySchemaV1( + val columnFamilySchema = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, 
VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) - store.createColFamilyIfAbsent(columnFamilyMetadata) + store.createColFamilyIfAbsent(columnFamilySchema) /** Whether state exists or not. */ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala index 8e69ea607546a..b5b902ab98245 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ListStateImplWithTTL.scala @@ -52,13 +52,13 @@ class ListStateImplWithTTL[S]( private lazy val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadata = new ColumnFamilySchemaV1( + val columnFamilySchema = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadata) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Whether state exists or not. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala index 2587f111a2077..b9558c1b6c310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImpl.scala @@ -34,11 +34,11 @@ class MapStateImpl[K, V]( private val stateTypesEncoder = new CompositeKeyStateEncoder( keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName) - val columnFamilyMetadata = new ColumnFamilySchemaV1( + val columnFamilySchema = new ColumnFamilySchemaV1( stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) - store.createColFamilyIfAbsent(columnFamilyMetadata) + store.createColFamilyIfAbsent(columnFamilySchema) /** Whether state exists or not. */ override def exists(): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala index 510c6870fc822..ef54624b9ecb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MapStateImplWithTTL.scala @@ -55,13 +55,13 @@ class MapStateImplWithTTL[K, V]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadata = new ColumnFamilySchemaV1( + val columnFamilySchema = new ColumnFamilySchemaV1( stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadata) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Whether state exists or not. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 0fad7f92c63d1..04acd1cb7dd24 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -98,7 +98,7 @@ class StatefulProcessorHandleImpl( private[sql] val stateVariables: util.List[StateVariableInfo] = new util.ArrayList[StateVariableInfo]() - private[sql] val columnFamilyMetadatas: util.List[ColumnFamilySchema] = + private[sql] val columnFamilySchemas: util.List[ColumnFamilySchema] = new util.ArrayList[ColumnFamilySchema]() private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000" @@ -132,6 +132,19 @@ class StatefulProcessorHandleImpl( def getHandleState: StatefulProcessorHandleState = currState + def validateStateVariableCreation(newColumnFamilySchema: ColumnFamilySchemaV1): Unit = { + existingColFamilies.get( + newColumnFamilySchema.columnFamilyName).foreach { existingColFamily => + // TODO: Fill in with conditions we need to validate new state variable creation + if (existingColFamily.json != newColumnFamilySchema.json) { + throw new RuntimeException( + s"State variable with name ${newColumnFamilySchema.columnFamilyName} already exists " + + s"with different schema. Existing schema: ${existingColFamily.json}, " + + s"New schema: ${newColumnFamilySchema.json}") + } + } + } + override def getValueState[T]( stateName: String, valEncoder: Encoder[T]): ValueState[T] = { @@ -139,7 +152,9 @@ class StatefulProcessorHandleImpl( incrementMetric("numValueStateVars") val resultState = new ValueStateImpl[T](store, stateName, keyEncoder, valEncoder) stateVariables.add(new StateVariableInfo(stateName, ValueState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + val colFamilySchema = resultState.columnFamilySchema + validateStateVariableCreation(colFamilySchema) + columnFamilySchemas.add(colFamilySchema) resultState } @@ -151,13 +166,15 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) - val valueStateWithTTL = new ValueStateImplWithTTL[T](store, stateName, + val resultState = new ValueStateImplWithTTL[T](store, stateName, keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numValueStateWithTTLVars") - ttlStates.add(valueStateWithTTL) + ttlStates.add(resultState) stateVariables.add(new StateVariableInfo(stateName, ValueState, true)) - columnFamilyMetadatas.add(valueStateWithTTL.columnFamilyMetadata) - valueStateWithTTL + val colFamilySchema = resultState.columnFamilySchema + validateStateVariableCreation(colFamilySchema) + columnFamilySchemas.add(colFamilySchema) + resultState } override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { @@ -165,7 +182,9 @@ class StatefulProcessorHandleImpl( incrementMetric("numListStateVars") val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) stateVariables.add(new StateVariableInfo(stateName, ListState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + val colFamilySchema = resultState.columnFamilySchema + validateStateVariableCreation(colFamilySchema) + columnFamilySchemas.add(resultState.columnFamilySchema) resultState } @@ -193,14 +212,16 @@ class StatefulProcessorHandleImpl( 
validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) - val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, + val resultState = new ListStateImplWithTTL[T](store, stateName, keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numListStateWithTTLVars") - ttlStates.add(listStateWithTTL) + ttlStates.add(resultState) stateVariables.add(new StateVariableInfo(stateName, ListState, true)) - columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) + val colFamilySchema = resultState.columnFamilySchema + validateStateVariableCreation(colFamilySchema) + columnFamilySchemas.add(resultState.columnFamilySchema) - listStateWithTTL + resultState } override def getMapState[K, V]( @@ -211,7 +232,9 @@ class StatefulProcessorHandleImpl( incrementMetric("numMapStateVars") val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) stateVariables.add(new StateVariableInfo(stateName, MapState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + val colFamilySchema = resultState.columnFamilySchema + validateStateVariableCreation(colFamilySchema) + columnFamilySchemas.add(resultState.columnFamilySchema) resultState } @@ -224,14 +247,16 @@ class StatefulProcessorHandleImpl( validateTTLConfig(ttlConfig, stateName) assert(batchTimestampMs.isDefined) - val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, + val resultState = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder, ttlConfig, batchTimestampMs.get) incrementMetric("numMapStateWithTTLVars") - ttlStates.add(mapStateWithTTL) + ttlStates.add(resultState) stateVariables.add(new StateVariableInfo(stateName, MapState, true)) - columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) + val colFamilySchema = resultState.columnFamilySchema + validateStateVariableCreation(colFamilySchema) + columnFamilySchemas.add(resultState.columnFamilySchema) - mapStateWithTTL + resultState } override def getQueryInfo(): QueryInfo = currQueryInfo diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala index efba8dd10b843..d6805d536adf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/TransformWithStateExec.scala @@ -392,7 +392,7 @@ case class TransformWithStateExec( operatorPropertiesFromExecutor.add(Map ("columnFamilyMetadatas" -> - JArray(processorHandle.columnFamilyMetadatas. + JArray(processorHandle.columnFamilySchemas. 
asScala.map(_.jsonValue).toList))) } @@ -431,6 +431,7 @@ case class TransformWithStateExec( override protected def doExecute(): RDD[InternalRow] = { metrics // force lazy init at driver colFamilySchemas + val colFamilySchemasBroadcast = sparkContext.broadcast(colFamilySchemas) validateTimeMode() @@ -460,8 +461,9 @@ case class TransformWithStateExec( storeConf = storeConf, hadoopConf = hadoopConfBroadcast.value.value ) - - processDataWithInitialState(store, childDataIterator, initStateIterator) + val colFamilySchemas = colFamilySchemasBroadcast.value + processDataWithInitialState( + store, childDataIterator, initStateIterator, colFamilySchemas) } else { initNewStateStoreAndProcessData(partitionId, hadoopConfBroadcast) { store => processDataWithInitialState(store, childDataIterator, initStateIterator) @@ -481,7 +483,8 @@ case class TransformWithStateExec( useColumnFamilies = true ) { case (store: StateStore, singleIterator: Iterator[InternalRow]) => - processData(store, singleIterator) + val colFamilySchemas = colFamilySchemasBroadcast.value + processData(store, singleIterator, colFamilySchemas) } } else { // If the query is running in batch mode, we need to create a new StateStore and instantiate @@ -545,7 +548,10 @@ case class TransformWithStateExec( * @param singleIterator The iterator of rows to process * @return An iterator of rows that are the result of processing the input rows */ - private def processData(store: StateStore, singleIterator: Iterator[InternalRow]): + private def processData( + store: StateStore, + singleIterator: Iterator[InternalRow], + colFamilySchemas: Map[String, ColumnFamilySchemaV1] = Map.empty): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl( store, getStateInfo.queryRunId, keyEncoder, timeMode, @@ -560,10 +566,11 @@ case class TransformWithStateExec( private def processDataWithInitialState( store: StateStore, childDataIterator: Iterator[InternalRow], - initStateIterator: Iterator[InternalRow]): + initStateIterator: Iterator[InternalRow], + colFamilySchemas: Map[String, ColumnFamilySchemaV1] = Map.empty): CompletionIterator[InternalRow, Iterator[InternalRow]] = { val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId, - keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics) + keyEncoder, timeMode, isStreaming, batchTimestampMs, metrics, colFamilySchemas) assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED) statefulProcessor.setHandle(processorHandle) statefulProcessor.init(outputMode, timeMode) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala index 3922f41b92d60..28fb07418bf29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImpl.scala @@ -42,12 +42,12 @@ class ValueStateImpl[S]( private val keySerializer = keyExprEnc.createSerializer() private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName) - val columnFamilyMetadata = new ColumnFamilySchemaV1( + val columnFamilySchema = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadata) + 
store.createColFamilyIfAbsent(columnFamilySchema) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala index a26071685b8bc..cf98697878574 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ValueStateImplWithTTL.scala @@ -49,13 +49,13 @@ class ValueStateImplWithTTL[S]( private val ttlExpirationMs = StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs) - val columnFamilyMetadata = new ColumnFamilySchemaV1( + val columnFamilySchema = new ColumnFamilySchemaV1( stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false) initialize() private def initialize(): Unit = { - store.createColFamilyIfAbsent(columnFamilyMetadata) + store.createColFamilyIfAbsent(columnFamilySchema) } /** Function to check if state exists. Returns true if present and false otherwise */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index d7cbcf505ca18..d6bfc11cca8c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.streaming import java.io.File +import java.time.Duration import java.util.UUID import org.apache.spark.SparkRuntimeException @@ -35,6 +36,36 @@ object TransformWithStateSuiteUtils { val NUM_SHUFFLE_PARTITIONS = 5 } +class RunningCountStatefulProcessorWithTTL(ttlConfig: TTLConfig) + extends StatefulProcessor[String, String, (String, String)] + with Logging { + + @transient private var _countState: ValueStateImplWithTTL[Long] = _ + + override def init( + outputMode: OutputMode, + timeMode: TimeMode): Unit = { + _countState = getHandle + .getValueState("countState", Encoders.scalaLong, ttlConfig) + .asInstanceOf[ValueStateImplWithTTL[Long]] + } + + override def handleInputRows( + key: String, + inputRows: Iterator[String], + timerValues: TimerValues, + expiredTimerInfo: ExpiredTimerInfo): Iterator[(String, String)] = { + val count = _countState.getOption().getOrElse(0L) + 1 + if (count == 3) { + _countState.clear() + Iterator.empty + } else { + _countState.update(count) + Iterator((key, count.toString)) + } + } +} + class RunningCountStatefulProcessor extends StatefulProcessor[String, String, (String, String)] with Logging { @transient protected var _countState: ValueState[Long] = _ @@ -368,6 +399,58 @@ class TransformWithStateSuite extends StateStoreMetricsTest } } + test("transformWithState - verify that query with ttl enabled after restart fails") { + withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> + classOf[RocksDBStateStoreProvider].getName) { + withTempDir { chkptDir => + val clock = new StreamManualClock + val inputData = MemoryStream[String] + val result = inputData.toDS() + .groupByKey(x => x) + .transformWithState(new RunningCountStatefulProcessorWithProcTimeTimer(), + TimeMode.ProcessingTime(), + OutputMode.Update()) + + testStream(result, OutputMode.Update())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + 
checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + CheckNewAnswer(("a", "1")), + StopStream + ) + + val result2 = inputData.toDS() + .groupByKey(x => x) + .transformWithState( + new RunningCountStatefulProcessorWithTTL(TTLConfig(Duration.ofMinutes(1))), + TimeMode.ProcessingTime(), + OutputMode.Append()) + + // verify that query with ttl enabled after restart fails + testStream(result2, OutputMode.Append())( + StartStream( + Trigger.ProcessingTime("1 second"), + triggerClock = clock, + checkpointLocation = chkptDir.getCanonicalPath + ), + AddData(inputData, "a"), + AdvanceManualClock(1 * 1000), + Execute { q => + val e = intercept[Exception] { + q.processAllAvailable() + } + assert(e.getMessage.contains("State variable with name" + + " countState already exists with different schema")) + } + ) + } + } + } + test("verify that operatorProperties contain all stateVariables") { withSQLConf(SQLConf.STATE_STORE_PROVIDER_CLASS.key -> classOf[RocksDBStateStoreProvider].getName) {
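      // Editorial sketch of this test's intent (the assertions below are assumed, not copied
      // from the original test body): run a transformWithState query, read back the operator
      // properties JSON written by the executor, and check that every state variable the
      // processor registered shows up there. Conceptually, assuming a "stateVariables" key
      // and json4s-style extraction:
      //
      //   val stateVars = parse(operatorProperties) \ "stateVariables"   // assumed JSON key
      //   assert(stateVars.children.map(v => (v \ "stateName").values).contains("countState"))
      //
      // This mirrors the StateVariableInfo entries that StatefulProcessorHandleImpl adds for
      // each getValueState/getListState/getMapState call in the hunks above.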