
[Spark] setThroughputControlGroupNameOnRequest (#34702)
* Change to set the throughput control group name on request options

* Update CosmosConfig.scala

* Update ItemsPartitionReader.scala

---------

Co-authored-by: annie-mac <[email protected]>
Co-authored-by: Fabian Meiswinkel <[email protected]>
3 people authored Apr 28, 2023
1 parent abf4482 commit 6257676
Showing 7 changed files with 57 additions and 12 deletions.
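
With this commit, the throughput control group name parsed from the connector configuration is stamped onto the request options of every operation path: bulk writes, point writes, queries, change feed reads, and schema inference sampling. Below is a minimal sketch of how a Spark job opts in, assuming the connector's documented spark.cosmos.throughputControl.* settings; account, database, and container names are placeholders.

// Minimal usage sketch (assumes the documented
// spark.cosmos.throughputControl.* configuration keys; all
// account/database/container values are placeholders).
val cosmosCfg = Map(
  "spark.cosmos.accountEndpoint" -> "https://<account>.documents.azure.com:443/",
  "spark.cosmos.accountKey" -> "<account-key>",
  "spark.cosmos.database" -> "myDatabase",
  "spark.cosmos.container" -> "myContainer",
  // The group name below is what ThroughputControlHelper now copies
  // onto each request's options.
  "spark.cosmos.throughputControl.enabled" -> "true",
  "spark.cosmos.throughputControl.name" -> "sparkJobGroup",
  "spark.cosmos.throughputControl.targetThroughputThreshold" -> "0.9",
  "spark.cosmos.throughputControl.globalControl.database" -> "myDatabase",
  "spark.cosmos.throughputControl.globalControl.container" -> "ThroughputControl"
)

// spark is an existing SparkSession.
val df = spark.read.format("cosmos.oltp").options(cosmosCfg).load()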
BulkWriter.scala
@@ -3,12 +3,11 @@
package com.azure.cosmos.spark

// scalastyle:off underscore.import
import com.azure.cosmos.implementation.HttpConstants
import com.azure.cosmos.implementation.apachecommons.lang.StringUtils
import com.azure.cosmos.{models, _}
import com.azure.cosmos.models._
import com.azure.cosmos.spark.BulkWriter.{BulkOperationFailedException, bulkWriterBoundedElastic, getThreadInfo}
import com.azure.cosmos.spark.diagnostics.DefaultDiagnostics
import com.azure.cosmos._
import reactor.core.scheduler.Scheduler

import scala.collection.mutable
@@ -30,9 +29,9 @@ import reactor.core.scala.publisher.{SFlux, SMono}
import reactor.core.scheduler.Schedulers

import java.util.UUID
import java.util.concurrent.{Semaphore, TimeUnit}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong, AtomicReference}
import java.util.concurrent.locks.ReentrantLock
import java.util.concurrent.{Semaphore, TimeUnit}
// scalastyle:off underscore.import
import scala.collection.JavaConverters._
// scalastyle:on underscore.import
@@ -92,6 +91,7 @@ class BulkWriter(container: CosmosAsyncContainer,
new CosmosBulkExecutionOptions(BulkWriter.bulkProcessingThresholds),
maxConcurrentPartitions
)
ThroughputControlHelper.populateThroughputControlGroupName(cosmosBulkExecutionOptions, writeConfig.throughputControlConfig)

private val operationContext = initializeOperationContext()
private val cosmosPatchHelperOpt = writeConfig.itemWriteStrategy match {
ChangeFeedPartitionReader.scala
@@ -78,6 +78,7 @@ private case class ChangeFeedPartitionReader
val options = CosmosChangeFeedRequestOptions
.createForProcessingFromContinuation(this.partition.continuationState.get)
.setMaxItemCount(readConfig.maxItemCount)
ThroughputControlHelper.populateThroughputControlGroupName(options, readConfig.throughputControlConfig)

var factoryMethod: java.util.function.Function[JsonNode, _] = (_: JsonNode) => {}
cosmosChangeFeedConfig.changeFeedMode match {
CosmosConfig.scala
@@ -577,7 +577,8 @@ private case class CosmosReadConfig(forceEventualConsistency: Boolean,
maxItemCount: Int,
prefetchBufferSize: Int,
dedicatedGatewayRequestOptions: DedicatedGatewayRequestOptions,
customQuery: Option[CosmosParameterizedQuery])
customQuery: Option[CosmosParameterizedQuery],
throughputControlConfig: Option[CosmosThroughputControlConfig] = None)

private object SchemaConversionModes extends Enumeration {
type SchemaConversionMode = Value
@@ -665,6 +666,7 @@ private object CosmosReadConfig {
result
}

val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg)

CosmosReadConfig(
forceEventualConsistency.get,
@@ -682,7 +684,8 @@
}
),
dedicatedGatewayRequestOptions,
customQuery)
customQuery,
throughputControlConfigOpt)
}
}

@@ -805,7 +808,8 @@ private case class CosmosWriteConfig(itemWriteStrategy: ItemWriteStrategy,
bulkMaxPendingOperations: Option[Int] = None,
pointMaxConcurrency: Option[Int] = None,
maxConcurrentCosmosPartitions: Option[Int] = None,
patchConfigs: Option[CosmosPatchConfigs] = None)
patchConfigs: Option[CosmosPatchConfigs] = None,
throughputControlConfig: Option[CosmosThroughputControlConfig] = None)

private object CosmosWriteConfig {
private val DefaultMaxRetryCount = 10
@@ -955,6 +959,7 @@ private object CosmosWriteConfig {
val maxRetryCountOpt = CosmosConfigEntry.parse(cfg, maxRetryCount)
val bulkEnabledOpt = CosmosConfigEntry.parse(cfg, bulkEnabled)
var patchConfigsOpt = Option.empty[CosmosPatchConfigs]
val throughputControlConfigOpt = CosmosThroughputControlConfig.parseThroughputControlConfig(cfg)

assert(bulkEnabledOpt.isDefined)

@@ -978,7 +983,8 @@
bulkMaxPendingOperations = CosmosConfigEntry.parse(cfg, bulkMaxPendingOperations),
pointMaxConcurrency = CosmosConfigEntry.parse(cfg, pointWriteConcurrency),
maxConcurrentCosmosPartitions = CosmosConfigEntry.parse(cfg, bulkMaxConcurrentPartitions),
patchConfigs = patchConfigsOpt)
patchConfigs = patchConfigsOpt,
throughputControlConfig = throughputControlConfigOpt)
}

def parsePatchColumnConfigs(cfg: Map[String, String], inputSchema: StructType): TrieMap[String, CosmosPatchColumnConfig] = {
CosmosTableSchemaInferrer.scala
@@ -119,6 +119,7 @@ private object CosmosTableSchemaInferrer
val queryOptions = new CosmosQueryRequestOptions()
queryOptions.setMaxBufferedItemCount(cosmosInferenceConfig.inferSchemaSamplingSize)
queryOptions.setDedicatedGatewayRequestOptions(cosmosReadConfig.dedicatedGatewayRequestOptions)
ThroughputControlHelper.populateThroughputControlGroupName(queryOptions, cosmosReadConfig.throughputControlConfig)

val queryText = cosmosInferenceConfig.inferSchemaQuery match {
case None =>
ItemsPartitionReader.scala
@@ -39,6 +39,9 @@ private case class ItemsPartitionReader
.CosmosQueryRequestOptionsHelper
.getCosmosQueryRequestOptionsAccessor
.disallowQueryPlanRetrieval(new CosmosQueryRequestOptions())

private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)
ThroughputControlHelper.populateThroughputControlGroupName(queryOptions, readConfig.throughputControlConfig)

private val operationContext = initializeOperationContext()

@@ -50,9 +53,6 @@
s"correlationActivityId ${diagnosticsContext.correlationActivityId}, " +
s"query: ${cosmosQuery.toString}, Context: ${operationContext.toString} ${getThreadInfo}")

private val readConfig = CosmosReadConfig.parseCosmosReadConfig(config)


private val clientCacheItem = CosmosClientCache(
CosmosClientConfiguration(config, readConfig.forceEventualConsistency),
Some(cosmosClientStateHandles.value.cosmosClientMetadataCaches),
PointWriter.scala
@@ -478,6 +478,8 @@ class PointWriter(container: CosmosAsyncContainer,
.getCosmosItemRequestOptionsAccessor
.setOperationContext(itemOption, operationContextAndListenerTuple)
}

ThroughputControlHelper.populateThroughputControlGroupName(itemOption, cosmosWriteConfig.throughputControlConfig)
itemOption
}

ThroughputControlHelper.scala
@@ -4,6 +4,7 @@
package com.azure.cosmos.spark

import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import com.azure.cosmos.models.{CosmosBulkExecutionOptions, CosmosChangeFeedRequestOptions, CosmosItemRequestOptions, CosmosQueryRequestOptions}
import com.azure.cosmos.spark.diagnostics.BasicLoggingTrait
import com.azure.cosmos.{CosmosAsyncContainer, ThroughputControlGroupConfigBuilder}
import org.apache.spark.broadcast.Broadcast
@@ -47,6 +48,42 @@ private object ThroughputControlHelper extends BasicLoggingTrait {
container
}

def populateThroughputControlGroupName(
bulkExecutionOptions: CosmosBulkExecutionOptions,
throughputControlConfigOpt: Option[CosmosThroughputControlConfig]
): Unit = {
if (throughputControlConfigOpt.isDefined) {
bulkExecutionOptions.setThroughputControlGroupName(throughputControlConfigOpt.get.groupName)
}
}

def populateThroughputControlGroupName(
itemRequestOptions: CosmosItemRequestOptions,
throughputControlConfigOpt: Option[CosmosThroughputControlConfig]
): Unit = {
if (throughputControlConfigOpt.isDefined) {
itemRequestOptions.setThroughputControlGroupName(throughputControlConfigOpt.get.groupName)
}
}

def populateThroughputControlGroupName(
queryRequestOptions: CosmosQueryRequestOptions,
throughputControlConfigOpt: Option[CosmosThroughputControlConfig]
): Unit = {
if (throughputControlConfigOpt.isDefined) {
queryRequestOptions.setThroughputControlGroupName(throughputControlConfigOpt.get.groupName)
}
}

def populateThroughputControlGroupName(
changeFeedRequestOptions: CosmosChangeFeedRequestOptions,
throughputControlConfigOpt: Option[CosmosThroughputControlConfig]
): Unit = {
if (throughputControlConfigOpt.isDefined) {
changeFeedRequestOptions.setThroughputControlGroupName(throughputControlConfigOpt.get.groupName)
}
}

private def enableGlobalThroughputControlGroup(
userConfig: Map[String, String],
cosmosContainerConfig: CosmosContainerConfig,
@@ -56,7 +93,6 @@ private object ThroughputControlHelper extends BasicLoggingTrait {
throughputControlConfig: CosmosThroughputControlConfig): Unit = {
val groupConfigBuilder = new ThroughputControlGroupConfigBuilder()
.groupName(throughputControlConfig.groupName)
.defaultControlGroup(true)

if (throughputControlConfig.targetThroughput.isDefined) {
groupConfigBuilder.targetThroughput(throughputControlConfig.targetThroughput.get)
@@ -103,7 +139,6 @@

val groupConfigBuilder = new ThroughputControlGroupConfigBuilder()
.groupName(throughputControlConfig.groupName)
.defaultControlGroup(true)

// If there is no SparkExecutorCount being captured, then fall back to use 1 executor count
// If the spark executor count is somehow 0, then fall back to 1 executor count
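Two things are worth noting in this last file. The four populateThroughputControlGroupName overloads exist because CosmosBulkExecutionOptions, CosmosItemRequestOptions, CosmosQueryRequestOptions, and CosmosChangeFeedRequestOptions each expose setThroughputControlGroupName but share no common interface for it. And the removed .defaultControlGroup(true) calls mean the group is no longer registered as the client-wide default; it is now selected explicitly per request. A minimal sketch of the per-request pattern follows (withGroupName is a hypothetical helper; the SDK types and setter are the ones used in the diff above):

import com.azure.cosmos.models.CosmosQueryRequestOptions

// Hypothetical illustration of the per-request selection; the group name
// would come from the parsed throughput control config when one is defined.
def withGroupName(
    options: CosmosQueryRequestOptions,
    groupNameOpt: Option[String]): CosmosQueryRequestOptions = {
  groupNameOpt.foreach(name => options.setThroughputControlGroupName(name))
  options
}

val queryOptions = withGroupName(new CosmosQueryRequestOptions(), Some("sparkJobGroup"))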
