From 37526de0f014a4ff25af37086f783a18439a25c0 Mon Sep 17 00:00:00 2001 From: Eric Marnadi Date: Tue, 28 May 2024 09:03:20 -0700 Subject: [PATCH] moving getListState/getMapState methods --- .../StatefulProcessorHandleImpl.scala | 148 +++++++++--------- 1 file changed, 74 insertions(+), 74 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala index 453e4bd25f0a9..218b9a5effaa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulProcessorHandleImpl.scala @@ -160,6 +160,80 @@ class StatefulProcessorHandleImpl( valueStateWithTTL } + override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { + verifyStateVarOperations("get_list_state") + incrementMetric("numListStateVars") + val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, ListState, false)) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + resultState + } + + /** + * Function to create new or return existing list state variable of given type + * with ttl. State values will not be returned past ttlDuration, and will be eventually removed + * from the state store. Any values in listState which have expired after ttlDuration will not + * returned on get() and will be eventually removed from the state. + * + * The user must ensure to call this function only within the `init()` method of the + * StatefulProcessor. + * + * @param stateName - name of the state variable + * @param valEncoder - SQL encoder for state variable + * @param ttlConfig - the ttl configuration (time to live duration etc.) + * @tparam T - type of state variable + * @return - instance of ListState of type T that can be used to store state persistently + */ + override def getListState[T]( + stateName: String, + valEncoder: Encoder[T], + ttlConfig: TTLConfig): ListState[T] = { + + verifyStateVarOperations("get_list_state") + validateTTLConfig(ttlConfig, stateName) + + assert(batchTimestampMs.isDefined) + val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, + keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) + incrementMetric("numListStateWithTTLVars") + ttlStates.add(listStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, ListState, true)) + columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) + + listStateWithTTL + } + + override def getMapState[K, V]( + stateName: String, + userKeyEnc: Encoder[K], + valEncoder: Encoder[V]): MapState[K, V] = { + verifyStateVarOperations("get_map_state") + incrementMetric("numMapStateVars") + val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) + stateVariables.add(new StateVariableInfo(stateName, MapState, false)) + columnFamilyMetadatas.add(resultState.columnFamilyMetadata) + resultState + } + + override def getMapState[K, V]( + stateName: String, + userKeyEnc: Encoder[K], + valEncoder: Encoder[V], + ttlConfig: TTLConfig): MapState[K, V] = { + verifyStateVarOperations("get_map_state") + validateTTLConfig(ttlConfig, stateName) + + assert(batchTimestampMs.isDefined) + val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, + valEncoder, ttlConfig, batchTimestampMs.get) + incrementMetric("numMapStateWithTTLVars") + ttlStates.add(mapStateWithTTL) + stateVariables.add(new StateVariableInfo(stateName, MapState, true)) + columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) + + mapStateWithTTL + } + override def getQueryInfo(): QueryInfo = currQueryInfo private lazy val timerState = new TimerStateImpl(store, timeMode, keyEncoder) @@ -250,80 +324,6 @@ class StatefulProcessorHandleImpl( } } - override def getListState[T](stateName: String, valEncoder: Encoder[T]): ListState[T] = { - verifyStateVarOperations("get_list_state") - incrementMetric("numListStateVars") - val resultState = new ListStateImpl[T](store, stateName, keyEncoder, valEncoder) - stateVariables.add(new StateVariableInfo(stateName, ListState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) - resultState - } - - /** - * Function to create new or return existing list state variable of given type - * with ttl. State values will not be returned past ttlDuration, and will be eventually removed - * from the state store. Any values in listState which have expired after ttlDuration will not - * returned on get() and will be eventually removed from the state. - * - * The user must ensure to call this function only within the `init()` method of the - * StatefulProcessor. - * - * @param stateName - name of the state variable - * @param valEncoder - SQL encoder for state variable - * @param ttlConfig - the ttl configuration (time to live duration etc.) - * @tparam T - type of state variable - * @return - instance of ListState of type T that can be used to store state persistently - */ - override def getListState[T]( - stateName: String, - valEncoder: Encoder[T], - ttlConfig: TTLConfig): ListState[T] = { - - verifyStateVarOperations("get_list_state") - validateTTLConfig(ttlConfig, stateName) - - assert(batchTimestampMs.isDefined) - val listStateWithTTL = new ListStateImplWithTTL[T](store, stateName, - keyEncoder, valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numListStateWithTTLVars") - ttlStates.add(listStateWithTTL) - stateVariables.add(new StateVariableInfo(stateName, ListState, true)) - columnFamilyMetadatas.add(listStateWithTTL.columnFamilyMetadata) - - listStateWithTTL - } - - override def getMapState[K, V]( - stateName: String, - userKeyEnc: Encoder[K], - valEncoder: Encoder[V]): MapState[K, V] = { - verifyStateVarOperations("get_map_state") - incrementMetric("numMapStateVars") - val resultState = new MapStateImpl[K, V](store, stateName, keyEncoder, userKeyEnc, valEncoder) - stateVariables.add(new StateVariableInfo(stateName, MapState, false)) - columnFamilyMetadatas.add(resultState.columnFamilyMetadata) - resultState - } - - override def getMapState[K, V]( - stateName: String, - userKeyEnc: Encoder[K], - valEncoder: Encoder[V], - ttlConfig: TTLConfig): MapState[K, V] = { - verifyStateVarOperations("get_map_state") - validateTTLConfig(ttlConfig, stateName) - - assert(batchTimestampMs.isDefined) - val mapStateWithTTL = new MapStateImplWithTTL[K, V](store, stateName, keyEncoder, userKeyEnc, - valEncoder, ttlConfig, batchTimestampMs.get) - incrementMetric("numMapStateWithTTLVars") - ttlStates.add(mapStateWithTTL) - stateVariables.add(new StateVariableInfo(stateName, MapState, true)) - columnFamilyMetadatas.add(mapStateWithTTL.columnFamilyMetadata) - - mapStateWithTTL - } - private def validateTTLConfig(ttlConfig: TTLConfig, stateName: String): Unit = { val ttlDuration = ttlConfig.ttlDuration if (timeMode != TimeMode.ProcessingTime()) {