
Commit 0afea22

Add support for parallel reader and writer for ShardReplicationTask

ankitkala committed Jul 13, 2021
1 parent a08ab83, commit 0afea22
Showing 7 changed files with 202 additions and 60 deletions.

ReplicationPlugin.kt
@@ -141,7 +141,7 @@ internal class ReplicationPlugin : Plugin(), ActionPlugin, PersistentTaskPlugin,
const val REPLICATION_EXECUTOR_NAME_FOLLOWER = "replication_follower"
val REPLICATED_INDEX_SETTING: Setting<String> = Setting.simpleString("index.plugins.replication.replicated",
Setting.Property.InternalIndex, Setting.Property.IndexScope)
val REPLICATION_CHANGE_BATCH_SIZE: Setting<Int> = Setting.intSetting("plugins.replication.ops_batch_size", 512, 16,
val REPLICATION_CHANGE_BATCH_SIZE: Setting<Int> = Setting.intSetting("plugins.replication.ops_batch_size", 50000, 16,
Setting.Property.Dynamic, Setting.Property.NodeScope)
val REPLICATION_LEADER_THREADPOOL_SIZE: Setting<Int> = Setting.intSetting("plugins.replication.leader.size", 0, 0,
Setting.Property.Dynamic, Setting.Property.NodeScope)
@@ -152,6 +152,10 @@ internal class ReplicationPlugin : Plugin(), ActionPlugin, PersistentTaskPlugin,
Setting.Property.Dynamic, Setting.Property.NodeScope)
val REPLICATION_FOLLOWER_RECOVERY_PARALLEL_CHUNKS: Setting<Int> = Setting.intSetting("plugins.replication.index.recovery.max_concurrent_file_chunks", 5, 1,
Setting.Property.Dynamic, Setting.Property.NodeScope)
val REPLICATION_PARALLEL_READ_PER_SHARD = Setting.intSetting("plugins.replication.parallel_reads_per_shard", 2, 1,
Setting.Property.Dynamic, Setting.Property.NodeScope)
val REPLICATION_PARALLEL_READ_POLL_DURATION = Setting.timeSetting("plugins.replication.parallel_reads_poll_duration", TimeValue.timeValueMillis(50), TimeValue.timeValueMillis(1),
TimeValue.timeValueSeconds(1), Setting.Property.Dynamic, Setting.Property.NodeScope)
}

override fun createComponents(client: Client, clusterService: ClusterService, threadPool: ThreadPool,
@@ -296,8 +300,8 @@ internal class ReplicationPlugin : Plugin(), ActionPlugin, PersistentTaskPlugin,
}

override fun getSettings(): List<Setting<*>> {
return listOf(REPLICATED_INDEX_SETTING, REPLICATION_CHANGE_BATCH_SIZE,
REPLICATION_LEADER_THREADPOOL_SIZE, REPLICATION_LEADER_THREADPOOL_QUEUE_SIZE,
return listOf(REPLICATED_INDEX_SETTING, REPLICATION_CHANGE_BATCH_SIZE, REPLICATION_LEADER_THREADPOOL_SIZE,
REPLICATION_LEADER_THREADPOOL_QUEUE_SIZE, REPLICATION_PARALLEL_READ_PER_SHARD,
REPLICATION_FOLLOWER_RECOVERY_CHUNK_SIZE, REPLICATION_FOLLOWER_RECOVERY_PARALLEL_CHUNKS)
}
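
Both plugins.replication.ops_batch_size and the new plugins.replication.parallel_reads_per_shard setting are declared Dynamic and registered in getSettings(), so they can be tuned on a live cluster. A minimal sketch of such an update from client code follows; the helper name and the values used are illustrative and not part of the commit.

import org.elasticsearch.action.admin.cluster.settings.ClusterUpdateSettingsRequest
import org.elasticsearch.client.Client
import org.elasticsearch.common.settings.Settings

// Illustrative helper (not from this commit): raise the number of parallel
// readers per shard and the batch size via a persistent cluster settings update.
fun tuneReplicationReads(client: Client) {
    val request = ClusterUpdateSettingsRequest()
        .persistentSettings(Settings.builder()
            .put("plugins.replication.parallel_reads_per_shard", 4)
            .put("plugins.replication.ops_batch_size", 50000)
            .build())
    client.admin().cluster().updateSettings(request).actionGet()
}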


GetChangesResponse.kt
@@ -22,13 +22,16 @@ import org.elasticsearch.index.translog.Translog

class GetChangesResponse(val changes: List<Translog.Operation>,
val fromSeqNo: Long,
val maxSeqNoOfUpdatesOrDeletes: Long) : ActionResponse() {
val maxSeqNoOfUpdatesOrDeletes: Long,
val lastSyncedGlobalCheckpoint: Long) : ActionResponse() {

constructor(inp: StreamInput) : this(inp.readList(Translog.Operation::readOperation), inp.readVLong(), inp.readLong())
constructor(inp: StreamInput) : this(inp.readList(Translog.Operation::readOperation), inp.readVLong(),
inp.readLong(), inp.readLong())

override fun writeTo(out: StreamOutput) {
out.writeCollection(changes, Translog.Operation::writeOperation)
out.writeVLong(fromSeqNo)
out.writeLong(maxSeqNoOfUpdatesOrDeletes)
out.writeLong(lastSyncedGlobalCheckpoint)
}
}
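
The new lastSyncedGlobalCheckpoint field is written last, so the read side must mirror the write order exactly. A minimal round-trip sketch, e.g. for a unit test, is shown below; it is illustrative only and not part of the commit.

import org.elasticsearch.common.io.stream.BytesStreamOutput
import org.elasticsearch.index.translog.Translog

// Serialize a response with no operations and read it back, checking that the
// newly added field survives the round trip.
fun roundTripGetChangesResponse() {
    val original = GetChangesResponse(emptyList<Translog.Operation>(), 10L, 20L, 30L)
    val out = BytesStreamOutput()
    original.writeTo(out)
    val deserialized = GetChangesResponse(out.bytes().streamInput())
    require(deserialized.lastSyncedGlobalCheckpoint == 30L)
}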

TransportGetChangesAction.kt
@@ -113,7 +113,7 @@ class TransportGetChangesAction @Inject constructor(threadPool: ThreadPool, clus
}
}
}
GetChangesResponse(ops, request.fromSeqNo, indexShard.maxSeqNoOfUpdatesOrDeletes)
GetChangesResponse(ops, request.fromSeqNo, indexShard.maxSeqNoOfUpdatesOrDeletes, indexShard.lastSyncedGlobalCheckpoint)
}
}
}

ShardReplicationChangesTracker.kt (new file)
@@ -0,0 +1,118 @@
package com.amazon.elasticsearch.replication.task.shard

import com.amazon.elasticsearch.replication.ReplicationPlugin
import kotlinx.coroutines.delay
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import org.elasticsearch.cluster.service.ClusterService
import org.elasticsearch.common.logging.Loggers
import org.elasticsearch.common.unit.TimeValue
import org.elasticsearch.index.shard.IndexShard
import java.util.*
import java.util.concurrent.atomic.AtomicLong
import kotlin.collections.ArrayList

/**
* Since we have added support for fetching batches of operations in parallel, we need to keep track of
* how many operations have been fetched and what batch needs to be fetched next. This creates the
* problem of concurrency with shared mutable state (https://kotlinlang.org/docs/shared-mutable-state-and-concurrency.html).
* ShardReplicationChangesTracker abstracts away all that complexity from ShardReplicationTask.
* Every reader coroutine in a shard has to interact with the tracker for:
* 1. Requesting the range of operations to be fetched in the batch.
* 2. Updating the final status of the batch fetch.
*/
class ShardReplicationChangesTracker(clusterService: ClusterService, indexShard: IndexShard) {
private val log = Loggers.getLogger(javaClass, indexShard.shardId())!!

private val mutex = Mutex()
private val missingBatches = Collections.synchronizedList(ArrayList<Pair<Long, Long>>())
private val observedSeqNoAtLeader = AtomicLong(indexShard.localCheckpoint)
private val seqNoAlreadyRequested = AtomicLong(indexShard.localCheckpoint)

@Volatile private var batchSize = clusterService.clusterSettings.get(ReplicationPlugin.REPLICATION_CHANGE_BATCH_SIZE)
@Volatile private var pollDuration: TimeValue = clusterService.clusterSettings.get(ReplicationPlugin.REPLICATION_PARALLEL_READ_POLL_DURATION)

init {
clusterService.clusterSettings.addSettingsUpdateConsumer(ReplicationPlugin.REPLICATION_CHANGE_BATCH_SIZE) { batchSize = it }
clusterService.clusterSettings.addSettingsUpdateConsumer(ReplicationPlugin.REPLICATION_PARALLEL_READ_POLL_DURATION) { pollDuration = it }
}

/**
* Provides a range of operations to be fetched next.
*
* Here are the guarantees that this method provides:
* 1. Every reader coroutine gets a unique range of operations to fetch.
* 2. The complete range of operations will eventually be fetched.
* 3. The mutex in this method ensures that only one coroutine requests a batch at a time.
* If there are multiple coroutines, they wait in order for their range of operations to fetch.
* 4. If we've already fetched all the operations from the leader, there will be one and only one
* reader per shard polling the leader.
*/
suspend fun requestBatchToFetch():Pair<Long, Long> {
mutex.withLock {
logDebug("Waiting to get batch. requested: ${seqNoAlreadyRequested.get()}, leader: ${observedSeqNoAtLeader.get()}")

// Wait till we have batch to fetch. Note that if seqNoAlreadyRequested is equal to observedSeqNoAtLeader,
// we still should be sending one more request to fetch which will just do a poll and eventually timeout
// if no new operations are there on the leader (configured via TransportGetChangesAction.WAIT_FOR_NEW_OPS_TIMEOUT)
while (seqNoAlreadyRequested.get() > observedSeqNoAtLeader.get() && missingBatches.isEmpty()) {
delay(pollDuration.millis)
}

// missing batch takes higher priority.
return if (missingBatches.isNotEmpty()) {
logDebug("Fetching missing batch ${missingBatches[0].first}-${missingBatches[0].second}")
missingBatches.removeAt(0)
} else {
// return the next batch to fetch and update seqNoAlreadyRequested.
val fromSeq = seqNoAlreadyRequested.getAndAdd(batchSize.toLong()) + 1
val toSeq = fromSeq + batchSize - 1
logDebug("Fetching the batch $fromSeq-$toSeq")
Pair(fromSeq, toSeq)
}
}
}

/**
* Records whether a particular range of operations was fetched successfully.
* In case of any failure (or if we didn't get the complete batch), we make sure that the
* missing operations are fetched in a subsequent batch.
*/
fun updateBatchFetched(success: Boolean, fromSeqNoRequested: Long, toSeqNoRequested: Long, toSeqNoReceived: Long, seqNoAtLeader: Long) {
if (success) {
// we shouldn't ever be getting more operations than requested.
assert(toSeqNoRequested >= toSeqNoReceived) { "${Thread.currentThread().getName()} Got more operations in the batch than requested" }
logDebug("Updating the batch fetched. ${fromSeqNoRequested}-${toSeqNoReceived}/${toSeqNoRequested}, seqNoAtLeader:$seqNoAtLeader")

// If we didn't get the complete batch that we had requested.
if (toSeqNoRequested > toSeqNoReceived) {
// If this is the last batch being fetched, update the seqNoAlreadyRequested.
if (seqNoAlreadyRequested.get() == toSeqNoRequested) {
seqNoAlreadyRequested.updateAndGet { toSeqNoReceived }
} else {
// Else, add the missing operations to the missing-batch list
logDebug("Didn't get the complete batch. Adding the missing operations ${toSeqNoReceived + 1}-${toSeqNoRequested}")
missingBatches.add(Pair(toSeqNoReceived + 1, toSeqNoRequested))
}
}

// Update the sequence number observed at leader.
observedSeqNoAtLeader.getAndUpdate { value -> if (seqNoAtLeader > value) seqNoAtLeader else value }
logDebug("observedSeqNoAtLeader: ${observedSeqNoAtLeader.get()}")
} else {
// If this is the last batch being fetched, update the seqNoAlreadyRequested.
if (seqNoAlreadyRequested.get() == toSeqNoRequested) {
seqNoAlreadyRequested.updateAndGet { fromSeqNoRequested - 1 }
} else {
// If this was not the last batch, we might have already fetched other batches of
// operations after this one. Add this range to the missing batches.
logDebug("Adding batch to missing $fromSeqNoRequested-$toSeqNoRequested")
missingBatches.add(Pair(fromSeqNoRequested, toSeqNoRequested))
}
}
}

private fun logDebug(msg: String) {
log.debug("${Thread.currentThread().name}: $msg")
}
}
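
For context, here is a simplified sketch (not from the commit) of how reader coroutines are expected to drive the tracker: each reader repeatedly requests a range, fetches it, and reports back what it actually received so the tracker can re-queue any gap. The fetch lambda is a hypothetical stand-in for the remote GetChanges call; ShardReplicationTask below launches one short-lived reader per batch instead of long-lived readers, but the tracker contract is the same.

import com.amazon.elasticsearch.replication.action.changes.GetChangesResponse
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch

suspend fun runReaders(tracker: ShardReplicationChangesTracker, readersPerShard: Int,
                       fetch: suspend (Long, Long) -> GetChangesResponse) = coroutineScope {
    repeat(readersPerShard) {
        launch {
            while (isActive) {
                val (fromSeqNo, toSeqNo) = tracker.requestBatchToFetch()
                try {
                    val response = fetch(fromSeqNo, toSeqNo)
                    // Report how far this batch actually got; the tracker re-queues any gap.
                    tracker.updateBatchFetched(true, fromSeqNo, toSeqNo,
                        response.changes.lastOrNull()?.seqNo() ?: fromSeqNo - 1,
                        response.lastSyncedGlobalCheckpoint)
                } catch (e: Exception) {
                    // On failure the whole range goes back into the missing-batch queue.
                    tracker.updateBatchFetched(false, fromSeqNo, toSeqNo, fromSeqNo - 1, -1)
                }
            }
        }
    }
}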

ShardReplicationTask.kt
@@ -17,6 +17,7 @@ package com.amazon.elasticsearch.replication.task.shard

import com.amazon.elasticsearch.replication.ReplicationException
import com.amazon.elasticsearch.replication.ReplicationPlugin.Companion.REPLICATION_CHANGE_BATCH_SIZE
import com.amazon.elasticsearch.replication.ReplicationPlugin.Companion.REPLICATION_PARALLEL_READ_PER_SHARD
import com.amazon.elasticsearch.replication.action.changes.GetChangesAction
import com.amazon.elasticsearch.replication.action.changes.GetChangesRequest
import com.amazon.elasticsearch.replication.action.changes.GetChangesResponse
@@ -35,8 +36,10 @@ import com.amazon.elasticsearch.replication.util.suspendExecuteWithRetries
import com.amazon.elasticsearch.replication.util.suspending
import kotlinx.coroutines.ObsoleteCoroutinesApi
import kotlinx.coroutines.cancel
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.isActive
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Semaphore
import org.elasticsearch.ElasticsearchException
import org.elasticsearch.ElasticsearchTimeoutException
@@ -77,16 +80,19 @@ class ShardReplicationTask(id: Long, type: String, action: String, description:

private val clusterStateListenerForTaskInterruption = ClusterStateListenerForTaskInterruption()

// Since these settings are not dynamic, updates to them only take effect after a pause-resume or on a new replication job.
@Volatile private var batchSize = clusterService.clusterSettings.get(REPLICATION_CHANGE_BATCH_SIZE)
@Volatile private var readersPerShard = clusterService.clusterSettings.get(REPLICATION_PARALLEL_READ_PER_SHARD)

init {
clusterService.clusterSettings.addSettingsUpdateConsumer(REPLICATION_CHANGE_BATCH_SIZE) { batchSize = it }
clusterService.clusterSettings.addSettingsUpdateConsumer(REPLICATION_PARALLEL_READ_PER_SHARD) { readersPerShard = it }
}

override val log = Loggers.getLogger(javaClass, followerShardId)!!

companion object {
fun taskIdForShard(shardId: ShardId) = "replication:${shardId}"
const val CONCURRENT_REQUEST_RATE_LIMIT = 10
}

@ObsoleteCoroutinesApi
@@ -142,52 +148,71 @@ class ShardReplicationTask(id: Long, type: String, action: String, description:

addListenerToInterruptTask()

// Not really used yet as we only have one get changes action at a time.
val rateLimiter = Semaphore(CONCURRENT_REQUEST_RATE_LIMIT)
var seqNo = indexShard.localCheckpoint + 1
val rateLimiter = Semaphore(readersPerShard)
val sequencer = TranslogSequencer(scope, replicationMetadata, followerShardId, remoteCluster, remoteShardId.indexName,
TaskId(clusterService.nodeName, id), client, rateLimiter, seqNo - 1)

// TODO: Redesign this to avoid sharing the rateLimiter between this block and the sequencer.
// This was done as a stopgap to work around a concurrency bug that needed to be fixed fast.
while (scope.isActive) {
rateLimiter.acquire()
try {
val changesResponse = getChanges(seqNo)
log.info("Got ${changesResponse.changes.size} changes starting from seqNo: $seqNo")
sequencer.send(changesResponse)
seqNo = changesResponse.changes.lastOrNull()?.seqNo()?.inc() ?: seqNo
} catch (e: ElasticsearchTimeoutException) {
log.info("Timed out waiting for new changes. Current seqNo: $seqNo")
rateLimiter.release()
continue
} catch (e: NodeNotConnectedException) {
log.info("Node not connected. Retrying request using a different node. $e")
delay(backOffForNodeDiscovery)
rateLimiter.release()
continue
}
//renew retention lease with global checkpoint so that any shard that picks up shard replication task has data until then.
try {
retentionLeaseHelper.renewRetentionLease(remoteShardId, indexShard.lastSyncedGlobalCheckpoint, followerShardId)
} catch (ex: Exception) {
when (ex) {
is RetentionLeaseInvalidRetainingSeqNoException, is RetentionLeaseNotFoundException -> {
throw ex
TaskId(clusterService.nodeName, id), client, indexShard.localCheckpoint)

val changeTracker = ShardReplicationChangesTracker(clusterService, indexShard)

coroutineScope {
while (scope.isActive) {
rateLimiter.acquire()
launch {
logDebug("Spawning the reader")
val batchToFetch = changeTracker.requestBatchToFetch()
val fromSeqNo = batchToFetch.first
val toSeqNo = batchToFetch.second
try {
logDebug("Getting changes $fromSeqNo-$toSeqNo")
val changesResponse = getChanges(fromSeqNo, toSeqNo)
logInfo("Got ${changesResponse.changes.size} changes starting from seqNo: $fromSeqNo")
sequencer.send(changesResponse)
logDebug("pushed to sequencer $fromSeqNo-$toSeqNo")
changeTracker.updateBatchFetched(true, fromSeqNo, toSeqNo, changesResponse.changes.lastOrNull()?.seqNo() ?: fromSeqNo - 1,
changesResponse.lastSyncedGlobalCheckpoint)
} catch (e: ElasticsearchTimeoutException) {
logInfo("Timed out waiting for new changes. Current seqNo: $fromSeqNo")
changeTracker.updateBatchFetched(false, fromSeqNo, toSeqNo, fromSeqNo - 1,-1)
} catch (e: NodeNotConnectedException) {
logInfo("Node not connected. Retrying request using a different node. $e")
delay(backOffForNodeDiscovery)
changeTracker.updateBatchFetched(false, fromSeqNo, toSeqNo, fromSeqNo - 1,-1)
} catch (e: Exception) {
logInfo("Unable to get changes from seqNo: $fromSeqNo. $e")
changeTracker.updateBatchFetched(false, fromSeqNo, toSeqNo, fromSeqNo - 1,-1)
} finally {
rateLimiter.release()
}
}

//renew retention lease with global checkpoint so that any shard that picks up shard replication task has data until then.
try {
retentionLeaseHelper.renewRetentionLease(remoteShardId, indexShard.lastSyncedGlobalCheckpoint, followerShardId)
} catch (ex: Exception) {
when (ex) {
is RetentionLeaseInvalidRetainingSeqNoException, is RetentionLeaseNotFoundException -> {
throw ex
}
else -> log.info("Exception renewing retention lease. Not an issue", ex);
}
else -> log.info("Exception renewing retention lease. Not an issue", ex);
}
}
}
sequencer.close()
}

private suspend fun getChanges(fromSeqNo: Long): GetChangesResponse {
private suspend fun getChanges(fromSeqNo: Long, toSeqNo: Long): GetChangesResponse {
val remoteClient = client.getRemoteClusterClient(remoteCluster)
val request = GetChangesRequest(remoteShardId, fromSeqNo, fromSeqNo + batchSize)
val request = GetChangesRequest(remoteShardId, fromSeqNo, toSeqNo)
return remoteClient.suspendExecuteWithRetries(replicationMetadata = replicationMetadata,
action = GetChangesAction.INSTANCE, req = request, log = log)
}
private fun logDebug(msg: String) {
log.debug("${Thread.currentThread().name}: $msg")
}
private fun logInfo(msg: String) {
log.info("${Thread.currentThread().name}: $msg")
}

override fun toString(): String {
return "ShardReplicationTask(from=${remoteCluster}$remoteShardId to=$followerShardId)"

(Diff of the remaining changed files not shown.)