Skip to content

Commit

Permalink
Adjust MicroBatchSize dynamically based on throttling rate in BulkExe…
Browse files Browse the repository at this point in the history
…cutor (#22290)

* Temp snapshot

* Adjusting MicroBatchSize dynamically in BulkExecutor.java

* Making sure Bulk Request 429 bubble up to the BulkExecutor so they are accounted for in dynamic MicroBatchSize adjustment

* Adjusting targeted bulk throttling retry rate to be a range

* Reducing lock contention in PartitionScopeThresholds.java

* Adding unit test coverage for dynamically changing micro batch size in BulkExecutor

* Adjusting log level in PartitionScopeThresholds

* Moving new API to V4_17_0 Beta annotation

* Adding missing copyright header

* Removing 408 special-casing

* Reacting to code review feedback

* Reacting to code review feedback

* Reenabling Direct tests

* Fixing a bug in the new buffering logic causing 400-BadRequest when the Batch request contains no actual operations after filtering out the dummy FlushOperations

* Fixing type

* Fixes for merge conflicts

* Dummy

* Update BulkWriter.scala

* Update BulkProcessingThresholds.java

* Reverting BridgeInternal changes

* Update BridgeInternal.java

* Update BulkProcessingOptionsTest.java

* Triggering flush on completion of input flux

* Self-code review feedback :-)

* Update BulkProcessingThresholds.java

* Fixing Nullref in BulkWriterTest

* Making FlushBuffersItemOperation a singleton

* Fixing build break

* Fixing test failure
  • Loading branch information
FabianMeiswinkel authored Jun 23, 2021
1 parent ab037df commit ddafb5f
Show file tree
Hide file tree
Showing 24 changed files with 887 additions and 155 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,23 @@
// Licensed under the MIT License.
package com.azure.cosmos.spark

import com.azure.cosmos.
{
BulkItemRequestOptions,
BulkOperations,
BulkProcessingOptions,
BulkProcessingThresholds,
CosmosAsyncContainer,
CosmosBulkOperationResponse,
CosmosException,
CosmosItemOperation
}
import com.azure.cosmos.implementation.ImplementationBridgeHelpers
import com.azure.cosmos.implementation.guava25.base.Preconditions
import com.azure.cosmos.implementation.spark.{OperationContextAndListenerTuple, OperationListener}
import com.azure.cosmos.models.PartitionKey
import com.azure.cosmos.spark.BulkWriter.{DefaultMaxPendingOperationPerCore, emitFailureHandler}
import com.azure.cosmos.spark.diagnostics.{DiagnosticsContext, DiagnosticsLoader, LoggerHelper, SparkTaskContext}
import com.azure.cosmos.{
BulkItemRequestOptions,
BulkOperations, BulkProcessingOptions, CosmosAsyncContainer, CosmosBulkOperationResponse, CosmosException, CosmosItemOperation
}
import com.fasterxml.jackson.databind.node.ObjectNode
import org.apache.spark.TaskContext
import reactor.core.Disposable
Expand Down Expand Up @@ -69,34 +76,46 @@ class BulkWriter(container: CosmosAsyncContainer,
private val totalScheduledMetrics = new AtomicLong(0)
private val totalSuccessfulIngestionMetrics = new AtomicLong(0)

private val bulkOptions = new BulkProcessingOptions[Object]()
initializeDiagnosticsIfConfigured()
private val bulkOptions = new BulkProcessingOptions[Object](null, BulkWriter.bulkProcessingThresholds)
private val operationContext = initializeOperationContext()

private def initializeDiagnosticsIfConfigured(): Unit = {
if (diagnosticsConfig.mode.isDefined) {
val taskContext = TaskContext.get
assert(taskContext != null)
private def initializeOperationContext(): SparkTaskContext = {
val taskContext = TaskContext.get

val diagnosticsContext: DiagnosticsContext = DiagnosticsContext(UUID.randomUUID().toString, "BulkWriter")
val diagnosticsContext: DiagnosticsContext = DiagnosticsContext(UUID.randomUUID().toString, "BulkWriter")

if (taskContext != null) {
val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
taskContext.stageId(),
taskContext.partitionId(),
taskContext.taskAttemptId(),
"")

val listener: OperationListener =
DiagnosticsLoader.getDiagnosticsProvider(diagnosticsConfig).getLogger(this.getClass)

val operationContextAndListenerTuple = new OperationContextAndListenerTuple(taskDiagnosticsContext, listener)
ImplementationBridgeHelpers.CosmosBulkProcessingOptionsHelper
.getCosmosBulkProcessingOptionAccessor()
.getCosmosBulkProcessingOptionAccessor
.setOperationContext(bulkOptions, operationContextAndListenerTuple)

taskDiagnosticsContext
} else{
SparkTaskContext(diagnosticsContext.correlationActivityId,
-1,
-1,
-1,
"")
}
}

private val subscriptionDisposable: Disposable = {
val bulkOperationResponseFlux: SFlux[CosmosBulkOperationResponse[Object]] =
container.processBulkOperations[Object](bulkInputEmitter.asFlux(), bulkOptions).asScala
container
.processBulkOperations[Object](
bulkInputEmitter.asFlux(),
bulkOptions)
.asScala

bulkOperationResponseFlux.subscribe(
resp => {
Expand All @@ -109,18 +128,18 @@ class BulkWriter(container: CosmosAsyncContainer,

if (resp.getException != null) {
Option(resp.getException) match {
case Some(cosmosException: CosmosException) => {
log.logDebug(s"encountered ${cosmosException.getStatusCode}")
case Some(cosmosException: CosmosException) =>
log.logDebug(s"encountered ${cosmosException.getStatusCode}, Context: ${operationContext.toString}")
if (shouldIgnore(cosmosException)) {
log.logDebug(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
s"ignored encountered ${cosmosException.getStatusCode}")
s"ignored encountered ${cosmosException.getStatusCode}, Context: ${operationContext.toString}")
totalSuccessfulIngestionMetrics.getAndIncrement()
// work done
} else if (shouldRetry(cosmosException, contextOpt.get)) {
// requeue
log.logWarning(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
s"encountered ${cosmosException.getStatusCode}, will retry! " +
s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}")
s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}, Context: {${operationContext.toString}}")

// this is to ensure the submission will happen on a different thread in background
// and doesn't block the active thread
Expand All @@ -136,14 +155,14 @@ class BulkWriter(container: CosmosAsyncContainer,
} else {
log.logWarning(s"for itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
s"encountered ${cosmosException.getStatusCode}, all retries exhausted! " +
s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}")
s"attemptNumber=${context.attemptNumber}, exceptionMessage=${cosmosException.getMessage}, Context: {${operationContext.toString}")
captureIfFirstFailure(cosmosException)
cancelWork()
}
}
case _ =>
log.logWarning(s"unexpected failure: itemId=[${context.itemId}], partitionKeyValue=[${context.partitionKeyValue}], " +
s"encountered , attemptNumber=${context.attemptNumber}, exceptionMessage=${resp.getException.getMessage}", resp.getException)
s"encountered , attemptNumber=${context.attemptNumber}, exceptionMessage=${resp.getException.getMessage}, " +
s"Context: ${operationContext.toString}", resp.getException)
captureIfFirstFailure(resp.getException)
cancelWork()
}
Expand All @@ -163,7 +182,7 @@ class BulkWriter(container: CosmosAsyncContainer,
},
errorConsumer = Option.apply(
ex => {
log.logError("Unexpected failure code path in Bulk ingestion", ex)
log.logError(s"Unexpected failure code path in Bulk ingestion, Context: ${operationContext.toString}", ex)
// if there is any failure this closes the bulk.
// at this point bulk api doesn't allow any retrying
// we don't know the list of failed item-operations
Expand All @@ -182,21 +201,21 @@ class BulkWriter(container: CosmosAsyncContainer,
override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
Preconditions.checkState(!closed.get())
if (errorCaptureFirstException.get() != null) {
log.logWarning("encountered failure earlier, rejecting new work")
log.logWarning(s"encountered failure earlier, rejecting new work, Context: ${operationContext.toString}")
throw errorCaptureFirstException.get()
}

semaphore.acquire()
val cnt = totalScheduledMetrics.getAndIncrement()
log.logDebug(s"total scheduled ${cnt}")
log.logDebug(s"total scheduled $cnt, Context: ${operationContext.toString}")

scheduleWriteInternal(partitionKeyValue, objectNode, OperationContext(getId(objectNode), partitionKeyValue, getETag(objectNode), 1))
}

private def scheduleWriteInternal(partitionKeyValue: PartitionKey, objectNode: ObjectNode, operationContext: OperationContext): Unit = {
activeTasks.incrementAndGet()
if (operationContext.attemptNumber > 1) {
log.logInfo(s"bulk scheduleWrite attemptCnt: ${operationContext.attemptNumber}")
log.logInfo(s"bulk scheduleWrite attemptCnt: ${operationContext.attemptNumber}, Context: ${operationContext.toString}")
}

val bulkItemOperation = writeConfig.itemWriteStrategy match {
Expand Down Expand Up @@ -226,46 +245,49 @@ class BulkWriter(container: CosmosAsyncContainer,

// the caller has to ensure that after invoking this method scheduleWrite doesn't get invoked
override def flushAndClose(): Unit = {
this.synchronized{
this.synchronized {
try {
if (closed.get()) {
// scalastyle:off return
return
// scalastyle:on return
}

log.logInfo("flushAndClose invoked")

log.logInfo(s"completed so far ${totalSuccessfulIngestionMetrics.get()}, pending tasks ${activeOperations.size}")
log.logInfo(s"flushAndClose invoked, Context: ${operationContext.toString}")
log.logInfo(s"completed so far ${totalSuccessfulIngestionMetrics.get()}, pending tasks ${activeOperations.size}, Context: ${operationContext.toString}")

// error handling, if there is any error and the subscription is cancelled
// the remaining tasks will not be processed hence we never reach activeTasks = 0
// once we do error handling we should think how to cover the scenario.
lock.lock()
try {
while (activeTasks.get() > 0 || errorCaptureFirstException.get != null) {
var activeTasksSnapshot = activeTasks.get()
while (activeTasksSnapshot > 0 || errorCaptureFirstException.get != null) {
log.logDebug(s"Waiting for pending activeTasks $activeTasksSnapshot, Context: ${operationContext.toString}")
pendingTasksCompleted.await()
activeTasksSnapshot = activeTasks.get()
log.logDebug(s"Waiting completed for pending activeTasks $activeTasksSnapshot, Context: ${operationContext.toString}")
}
} finally {
lock.unlock()
}

log.logInfo("invoking bulkInputEmitter.onComplete()")
log.logInfo(s"invoking bulkInputEmitter.onComplete(), Context: ${operationContext.toString}")
semaphore.release(activeTasks.get())
bulkInputEmitter.tryEmitComplete()

// which error to report?
if (errorCaptureFirstException.get() != null) {
log.logError(s"flushAndClose throw captured error ${errorCaptureFirstException.get().getMessage}")
log.logError(s"flushAndClose throw captured error ${errorCaptureFirstException.get().getMessage}, " +
s"Context: ${operationContext.toString}")
throw errorCaptureFirstException.get()
}

assume(activeTasks.get() == 0)
assume(activeOperations.isEmpty)
assume(semaphore.availablePermits() == maxPendingOperations)

log.logInfo(s"flushAndClose completed with no error. " +
s"totalSuccessfulIngestionMetrics=${totalSuccessfulIngestionMetrics.get()}, totalScheduled=${totalScheduledMetrics}")
s"totalSuccessfulIngestionMetrics=${totalSuccessfulIngestionMetrics.get()}, " +
s"totalScheduled=$totalScheduledMetrics, Context: ${operationContext.toString}")
assume(totalScheduledMetrics.get() == totalSuccessfulIngestionMetrics.get)
} finally {
closed.set(true)
Expand All @@ -276,16 +298,20 @@ class BulkWriter(container: CosmosAsyncContainer,
private def markTaskCompletion(): Unit = {
lock.lock()
try {
if (activeTasks.decrementAndGet() == 0 || errorCaptureFirstException.get() != null) {
val activeTasksLeftSnapshot = activeTasks.decrementAndGet()
val exceptionSnapshot = errorCaptureFirstException.get()
if (activeTasksLeftSnapshot == 0 || exceptionSnapshot != null) {
log.logDebug(s"MarkTaskCompletion, Active tasks left: $activeTasksLeftSnapshot, " +
s"error: $exceptionSnapshot, Context: ${operationContext.toString}")
pendingTasksCompleted.signal()
}
} finally {
lock.unlock()
}
}

private def captureIfFirstFailure(throwable: Throwable) = {
log.logError("capture failure", throwable)
private def captureIfFirstFailure(throwable: Throwable): Unit = {
log.logError(s"capture failure, Context: {${operationContext.toString}}", throwable)
lock.lock()
try {
errorCaptureFirstException.compareAndSet(null, throwable)
Expand All @@ -296,7 +322,8 @@ class BulkWriter(container: CosmosAsyncContainer,
}

private def cancelWork(): Unit = {
log.logInfo(s"cancelling remaining un process tasks ${activeTasks.get}")
log.logInfo(s"cancelling remaining unprocessed tasks ${activeTasks.get}, " +
s"Context: ${operationContext.toString}")
subscriptionDisposable.dispose()
}

Expand Down Expand Up @@ -341,10 +368,12 @@ private object BulkWriter {
// hence we want 2MB/ 1KB items per partition to be buffered
// 2 * 1024 * 167 items should get buffered on a 16 CPU core VM
// so per CPU core we want (2 * 1024 * 167 / 16) max items to be buffered
val DefaultMaxPendingOperationPerCore = 2 * 1024 * 167 / 16
val DefaultMaxPendingOperationPerCore: Int = 2 * 1024 * 167 / 16

val emitFailureHandler: EmitFailureHandler =
(signalType, emitResult) => if (emitResult.equals(EmitResult.FAIL_NON_SERIALIZED)) true else false
(_, emitResult) => if (emitResult.equals(EmitResult.FAIL_NON_SERIALIZED)) true else false

val bulkProcessingThresholds = new BulkProcessingThresholds[Object]()
}

//scalastyle:on multiple.string.literals
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ private case class ItemsPartitionReader
val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
taskContext.stageId(),
taskContext.partitionId(),
taskContext.taskAttemptId(),
feedRange.toString + " " + cosmosQuery.toSqlQuerySpec.getQueryText)

val listener: OperationListener =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class PointWriter(container: CosmosAsyncContainer, cosmosWriteConfig: CosmosWrit
private val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
taskContext.stageId(),
taskContext.partitionId(),
taskContext.taskAttemptId(),
"PointWriter")

override def scheduleWrite(partitionKeyValue: PartitionKey, objectNode: ObjectNode): Unit = {
Expand Down Expand Up @@ -287,9 +288,11 @@ class PointWriter(container: CosmosAsyncContainer, cosmosWriteConfig: CosmosWrit
private def getOptions(): CosmosItemRequestOptions = {
val options = new CosmosItemRequestOptions()
if (diagnosticsConfig.mode.isDefined) {
val taskDiagnosticsContext = SparkTaskContext(diagnosticsContext.correlationActivityId,
val taskDiagnosticsContext = SparkTaskContext(
diagnosticsContext.correlationActivityId,
taskContext.stageId(),
taskContext.partitionId(),
taskContext.taskAttemptId(),
"")

val listener: OperationListener =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,15 @@ private[spark] case class DeleteOperation(sparkTaskContext: SparkTaskContext, it
private[spark] case class SparkTaskContext(correlationActivityId: String,
stageId: Int,
partitionId: Long,
taskAttemptId: Long,
details: String) extends OperationContext {

@transient private lazy val cachedToString = {
"SparkTaskContext(" +
"correlationActivityId=" + correlationActivityId +
",stageId=" + stageId +
",partitionId=" + partitionId +
",taskAttemptId=" + taskAttemptId +
",details=" + details + ")"
}

Expand Down
Loading

0 comments on commit ddafb5f

Please sign in to comment.