renaming to ColumnFamilySchema
ericm-db committed Jun 3, 2024
1 parent 970cc13 commit abf8c04
Showing 13 changed files with 107 additions and 33 deletions.
------------------------------------------------------------
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
- import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.ListState

/**
@@ -44,7 +44,7 @@ class ListStateImpl[S](

private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)

- val columnFamilyMetadata = new ColumnFamilyMetadataV1(
+ val columnFamilyMetadata = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
store.createColFamilyIfAbsent(columnFamilyMetadata)
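
Note: taken out of context, the registration pattern each of these state variables follows after the rename is sketched below. The helper object is hypothetical and only reuses names that appear in this diff.

// Sketch, not part of the commit: describe a column family with the renamed
// ColumnFamilySchemaV1 and register it on the store.
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore}

object ColumnFamilyRegistrationSketch {
  def registerSimpleState(store: StateStore, stateName: String): ColumnFamilySchemaV1 = {
    // Name, key/value row schemas, key encoder spec, and whether a key may hold multiple values.
    val schema = new ColumnFamilySchemaV1(
      stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
      NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
    // Only creates the column family if it is absent, then state operations can target it by name.
    store.createColFamilyIfAbsent(schema)
    schema
  }
}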

------------------------------------------------------------
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
- import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
+ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{ListState, TTLConfig}
import org.apache.spark.util.NextIterator

@@ -52,7 +52,7 @@ class ListStateImplWithTTL[S](
private lazy val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

- val columnFamilyMetadata = new ColumnFamilyMetadataV1(
+ val columnFamilyMetadata = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), true)
initialize()
------------------------------------------------------------
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
- import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
+ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors, UnsafeRowPair}
import org.apache.spark.sql.streaming.MapState

class MapStateImpl[K, V](
@@ -34,7 +34,7 @@ class MapStateImpl[K, V](
private val stateTypesEncoder = new CompositeKeyStateEncoder(
keySerializer, userKeyEnc, valEncoder, COMPOSITE_KEY_ROW_SCHEMA, stateName)

- val columnFamilyMetadata = new ColumnFamilyMetadataV1(
+ val columnFamilyMetadata = new ColumnFamilySchemaV1(
stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)

------------------------------------------------------------
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
- import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
+ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, PrefixKeyScanStateEncoderSpec, StateStore, StateStoreErrors}
import org.apache.spark.sql.streaming.{MapState, TTLConfig}
import org.apache.spark.util.NextIterator

@@ -55,7 +55,7 @@ class MapStateImplWithTTL[K, V](
private val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

- val columnFamilyMetadata = new ColumnFamilyMetadataV1(
+ val columnFamilyMetadata = new ColumnFamilySchemaV1(
stateName, COMPOSITE_KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
PrefixKeyScanStateEncoderSpec(COMPOSITE_KEY_ROW_SCHEMA, 1), false)
initialize()
------------------------------------------------------------
@@ -84,7 +84,8 @@ class StatefulProcessorHandleImpl(
timeMode: TimeMode,
isStreaming: Boolean = true,
batchTimestampMs: Option[Long] = None,
- metrics: Map[String, SQLMetric] = Map.empty)
+ metrics: Map[String, SQLMetric] = Map.empty,
+ existingColFamilies: Map[String, ColumnFamilyAccumulator] = Map.empty)
extends StatefulProcessorHandle with Logging {
import StatefulProcessorHandleState._

@@ -97,8 +98,8 @@
private[sql] val stateVariables: util.List[StateVariableInfo] =
new util.ArrayList[StateVariableInfo]()

- private[sql] val columnFamilyMetadatas: util.List[ColumnFamilyMetadata] =
-   new util.ArrayList[ColumnFamilyMetadata]()
+ private[sql] val columnFamilyMetadatas: util.List[ColumnFamilySchema] =
+   new util.ArrayList[ColumnFamilySchema]()

private val BATCH_QUERY_ID = "00000000-0000-0000-0000-000000000000"

@@ -168,6 +169,7 @@ class StatefulProcessorHandleImpl(
throw StateStoreErrors.cannotPerformOperationWithInvalidHandleState(operationType,
currState.toString)
}
+
}

private def verifyTimerOperations(operationType: String): Unit = {
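
Note: this commit only threads the new existingColFamilies parameter through the constructor; the diff does not show it being read yet. One plausible use, stated purely as an assumption, is to look up a schema already reported for a state name before creating its column family:

// Assumed usage, not in this commit: existingColFamilies maps state names to the
// accumulators built on the driver, so the handle could consult them like this.
def knownSchema(
    existingColFamilies: Map[String, ColumnFamilyAccumulator],
    stateName: String): Option[ColumnFamilySchema] =
  existingColFamilies.get(stateName).map(_.value)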
------------------------------------------------------------
@@ -93,6 +93,22 @@ case class TransformWithStateExec(
"operatorPropsFromExecutor"
)

+   private lazy val colFamilyAccumulators: Map[String, ColumnFamilyAccumulator] =
+     initializeColFamilyAccumulators()
+
+   private def initializeColFamilyAccumulators(): Map[String, ColumnFamilyAccumulator] = {
+     val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation,
+       getStateInfo.operatorId.toString)
+     val hadoopConf = session.sqlContext.sessionState.newHadoopConf()
+
+     val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf)
+
+     reader.read.map { colFamilyMetadata =>
+       val acc = ColumnFamilyAccumulator.create(colFamilyMetadata, sparkContext)
+       colFamilyMetadata.asInstanceOf[ColumnFamilySchemaV1].columnFamilyName -> acc
+     }.toMap
+   }
+
/** Metadata of this stateful operator and its states stores. */
override def operatorStateMetadata(): OperatorStateMetadata = {
val info = getStateInfo
@@ -414,6 +430,7 @@

override protected def doExecute(): RDD[InternalRow] = {
metrics // force lazy init at driver
+ colFamilyAccumulators

validateTimeMode()

@@ -453,10 +470,7 @@
}
} else {
if (isStreaming) {
- val stateCheckpointPath = new Path(stateInfo.get.checkpointLocation,
-   getStateInfo.operatorId.toString)
- val hadoopConf = session.sqlContext.sessionState.newHadoopConf()
- val reader = new SchemaV3Reader(stateCheckpointPath, hadoopConf)
+
child.execute().mapPartitionsWithStateStore[InternalRow](
getStateInfo,
KEY_ROW_SCHEMA,
@@ -535,7 +549,7 @@
CompletionIterator[InternalRow, Iterator[InternalRow]] = {
val processorHandle = new StatefulProcessorHandleImpl(
store, getStateInfo.queryRunId, keyEncoder, timeMode,
- isStreaming, batchTimestampMs, metrics)
+ isStreaming, batchTimestampMs, metrics, colFamilyAccumulators)
assert(processorHandle.getHandleState == StatefulProcessorHandleState.CREATED)
statefulProcessor.setHandle(processorHandle)
statefulProcessor.init(outputMode, timeMode)
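
Note: the bare colFamilyAccumulators reference added to doExecute() follows the same idiom as the metrics line directly above it: touching the lazy val on the driver registers the accumulators with the SparkContext before any task closures are serialized. A minimal, self-contained illustration of that idiom (a general Spark pattern, not code from this commit):

import org.apache.spark.sql.SparkSession

object ForceDriverInitSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("force-init").getOrCreate()
    val sc = spark.sparkContext

    // Lazily-built driver-side state, like colFamilyAccumulators above.
    lazy val rows = sc.longAccumulator("rows")

    rows // force lazy init at driver, mirroring the line added to doExecute()

    val acc = rows // capture only the accumulator in the task closure below
    sc.parallelize(1 to 10).foreach(_ => acc.add(1))

    println(acc.value) // 10
    spark.stop()
  }
}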
------------------------------------------------------------
@@ -20,7 +20,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA}
- import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore}
+ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore}
import org.apache.spark.sql.streaming.ValueState

/**
@@ -42,7 +42,7 @@ class ValueStateImpl[S](
private val keySerializer = keyExprEnc.createSerializer()
private val stateTypesEncoder = StateTypesEncoder(keySerializer, valEncoder, stateName)

- val columnFamilyMetadata = new ColumnFamilyMetadataV1(
+ val columnFamilyMetadata = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA, NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
initialize()

------------------------------------------------------------
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution.streaming
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.execution.streaming.TransformWithStateKeyValueRowSchema.{KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL}
- import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyMetadataV1, NoPrefixKeyStateEncoderSpec, StateStore}
+ import org.apache.spark.sql.execution.streaming.state.{ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec, StateStore}
import org.apache.spark.sql.streaming.{TTLConfig, ValueState}

/**
@@ -49,7 +49,7 @@ class ValueStateImplWithTTL[S](
private val ttlExpirationMs =
StateTTL.calculateExpirationTimeForDuration(ttlConfig.ttlDuration, batchTimestampMs)

- val columnFamilyMetadata = new ColumnFamilyMetadataV1(
+ val columnFamilyMetadata = new ColumnFamilySchemaV1(
stateName, KEY_ROW_SCHEMA, VALUE_ROW_SCHEMA_WITH_TTL,
NoPrefixKeyStateEncoderSpec(KEY_ROW_SCHEMA), false)
initialize()
------------------------------------------------------------
@@ -246,7 +246,7 @@ private[sql] class HDFSBackedStateStoreProvider extends StateStoreProvider with
throw StateStoreErrors.unsupportedOperationException("merge", providerName)
}

- override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = {
+ override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = {
throw StateStoreErrors.multipleColumnFamiliesNotSupported(providerName)
}
}
------------------------------------------------------------
@@ -265,7 +265,7 @@ private[sql] class RocksDBStateStoreProvider
result
}

- override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = {
+ override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = {
createColFamilyIfAbsent(
colFamilyMetadata.columnFamilyName,
colFamilyMetadata.keySchema,
------------------------------------------------------------
@@ -26,22 +26,23 @@ import org.json4s.JsonDSL._
import org.json4s.jackson.JsonMethods
import org.json4s.jackson.JsonMethods.{compact, render}

+ import org.apache.spark.SparkContext
import org.apache.spark.sql.execution.streaming.{CheckpointFileManager, MetadataVersionUtil}
import org.apache.spark.sql.types.StructType
- import org.apache.spark.util.Utils
+ import org.apache.spark.util.{AccumulatorV2, Utils}

- sealed trait ColumnFamilyMetadata extends Serializable {
+ sealed trait ColumnFamilySchema extends Serializable {
def jsonValue: JsonAST.JObject

def json: String
}

- case class ColumnFamilyMetadataV1(
+ case class ColumnFamilySchemaV1(
val columnFamilyName: String,
val keySchema: StructType,
val valueSchema: StructType,
val keyStateEncoderSpec: KeyStateEncoderSpec,
- val multipleValuesPerKey: Boolean) extends ColumnFamilyMetadata {
+ val multipleValuesPerKey: Boolean) extends ColumnFamilySchema {
def jsonValue: JsonAST.JObject = {
("columnFamilyName" -> JString(columnFamilyName)) ~
("keySchema" -> keySchema.json) ~
@@ -55,12 +56,69 @@ case class ColumnFamilyMetadataV1(
}
}

- object ColumnFamilyMetadataV1 {
-   def fromJson(json: List[Map[String, Any]]): List[ColumnFamilyMetadata] = {
+ class ColumnFamilyAccumulator(
+     columnFamilyMetadata: ColumnFamilySchema) extends
+   AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema] {
+
+   private var _value: ColumnFamilySchema = columnFamilyMetadata
+   /**
+    * Returns if this accumulator is zero value or not. e.g. for a counter accumulator, 0 is zero
+    * value; for a list accumulator, Nil is zero value.
+    */
+   override def isZero: Boolean = _value == null
+
+   /**
+    * Creates a new copy of this accumulator.
+    */
+   override def copy(): AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema] = {
+     new ColumnFamilyAccumulator(_value)
+   }
+
+   /**
+    * Resets this accumulator, which is zero value. i.e. call `isZero` must
+    * return true.
+    */
+   override def reset(): Unit = {
+     _value = null
+   }
+
+   /**
+    * Takes the inputs and accumulates.
+    */
+   override def add(v: ColumnFamilySchema): Unit = {
+     _value = v
+   }
+
+   /**
+    * Merges another same-type accumulator into this one and update its state, i.e. this should be
+    * merge-in-place.
+    */
+   override def merge(other: AccumulatorV2[ColumnFamilySchema, ColumnFamilySchema]): Unit = {
+     _value = other.value
+   }
+
+   /**
+    * Defines the current value of this accumulator
+    */
+   override def value: ColumnFamilySchema = _value
+ }
+
+ object ColumnFamilyAccumulator {
+   def create(
+       columnFamilyMetadata: ColumnFamilySchema,
+       sparkContext: SparkContext): ColumnFamilyAccumulator = {
+     val acc = new ColumnFamilyAccumulator(columnFamilyMetadata)
+     acc.register(sparkContext)
+     acc
+   }
+ }
+
+ object ColumnFamilySchemaV1 {
+   def fromJson(json: List[Map[String, Any]]): List[ColumnFamilySchema] = {
assert(json.isInstanceOf[List[_]])

json.map { colFamilyMap =>
- new ColumnFamilyMetadataV1(
+ new ColumnFamilySchemaV1(
colFamilyMap("columnFamilyName").asInstanceOf[String],
StructType.fromString(colFamilyMap("keySchema").asInstanceOf[String]),
StructType.fromString(colFamilyMap("valueSchema").asInstanceOf[String]),
@@ -122,7 +180,7 @@ object SchemaHelper {
private val schemaFilePath = SchemaV3Writer.getSchemaFilePath(stateCheckpointPath)

private lazy val fm = CheckpointFileManager.create(stateCheckpointPath, hadoopConf)
- def read: List[ColumnFamilyMetadata] = {
+ def read: List[ColumnFamilySchema] = {
if (!fm.exists(schemaFilePath)) {
return List.empty
}
@@ -139,7 +197,7 @@
s"Expected List but got ${deserializedList.getClass}")
val columnFamilyMetadatas = deserializedList.asInstanceOf[List[Map[String, Any]]]
// Extract each JValue to StateVariableInfo
- ColumnFamilyMetadataV1.fromJson(columnFamilyMetadatas)
+ ColumnFamilySchemaV1.fromJson(columnFamilyMetadatas)
}
}
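
Note: put together, ColumnFamilyAccumulator is a driver-registered AccumulatorV2 whose value is simply the most recent ColumnFamilySchema it was given; merge overwrites rather than combines. A self-contained sketch of the API, assuming these internal classes are visible from the calling code (the standalone driver program is illustrative, not something this commit adds):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.state.{ColumnFamilyAccumulator, ColumnFamilySchemaV1, NoPrefixKeyStateEncoderSpec}
import org.apache.spark.sql.types.{StringType, StructField, StructType}

object ColumnFamilyAccumulatorSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("cf-acc-sketch").getOrCreate()
    val sc = spark.sparkContext

    val keySchema = StructType(Seq(StructField("key", StringType)))
    val valueSchema = StructType(Seq(StructField("value", StringType)))
    val schema = ColumnFamilySchemaV1(
      "countState", keySchema, valueSchema,
      NoPrefixKeyStateEncoderSpec(keySchema), multipleValuesPerKey = false)

    // Driver: register the accumulator, seeded with a schema (for example one read back by SchemaV3Reader).
    val acc = ColumnFamilyAccumulator.create(schema, sc)

    // Whoever learns of an updated schema can report it; the accumulator keeps the latest value.
    acc.add(schema)

    println(acc.value.json) // the JSON produced by ColumnFamilySchemaV1.jsonValue
    spark.stop()
  }
}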

------------------------------------------------------------
@@ -138,7 +138,7 @@ trait StateStore extends ReadStateStore {
isInternal: Boolean = false): Unit

def createColFamilyIfAbsent(
- colFamilyMetadata: ColumnFamilyMetadataV1
+ colFamilyMetadata: ColumnFamilySchemaV1
): Unit

/**
------------------------------------------------------------
@@ -75,7 +75,7 @@ class MemoryStateStore extends StateStore() {
throw new UnsupportedOperationException("Doesn't support multiple values per key")
}

- override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilyMetadataV1): Unit = {
+ override def createColFamilyIfAbsent(colFamilyMetadata: ColumnFamilySchemaV1): Unit = {
throw StateStoreErrors.removingColumnFamiliesNotSupported("MemoryStateStoreProvider")
}
}
