Skip to content

Commit

Permalink
Fix for missing ShardReplicationTasks on new nodes (opensearch-projec…
Browse files Browse the repository at this point in the history
…t#497)

Signed-off-by: Ankit Kala <[email protected]>

Signed-off-by: Ankit Kala <[email protected]>
  • Loading branch information
ankitkala authored Aug 26, 2022
1 parent 8f0a55c commit 805f686
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 29 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
ReplicationState.INIT_FOLLOW -> {
log.info("Starting shard tasks")
addIndexBlockForReplication()
startShardFollowTasks(emptyMap())
FollowingState(startNewOrMissingShardTasks())

}
ReplicationState.FOLLOWING -> {
if (currentTaskState is FollowingState) {
Expand All @@ -206,8 +207,8 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
// Tasks need to be started
state
} else {
state = pollShardTaskStatus((followingTaskState as FollowingState).shardReplicationTasks)
followingTaskState = startMissingShardTasks((followingTaskState as FollowingState).shardReplicationTasks)
state = pollShardTaskStatus()
followingTaskState = FollowingState(startNewOrMissingShardTasks())
when (state) {
is MonitoringState -> {
updateMetadata()
Expand Down Expand Up @@ -285,24 +286,7 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
clusterService.addListener(this)
}

private suspend fun startMissingShardTasks(shardTasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): IndexReplicationState {
val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)

val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
.map { task -> task.params as ShardReplicationParams }
.collect(Collectors.toList())

val runningTasksForCurrentIndex = shardTasks.filter { entry -> runningShardTasks.find { task -> task.followerShardId == entry.key } != null}

val numMissingTasks = shardTasks.size - runningTasksForCurrentIndex.size
if (numMissingTasks > 0) {
log.info("Starting $numMissingTasks missing shard task(s)")
return startShardFollowTasks(runningTasksForCurrentIndex)
}
return FollowingState(shardTasks)
}

private suspend fun pollShardTaskStatus(shardTasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): IndexReplicationState {
private suspend fun pollShardTaskStatus(): IndexReplicationState {
val failedShardTasks = findAllReplicationFailedShardTasks(followerIndexName, clusterService.state())
if (failedShardTasks.isNotEmpty()) {
log.info("Failed shard tasks - ", failedShardTasks)
Expand Down Expand Up @@ -343,11 +327,16 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
registerCloseListeners()
val clusterState = clusterService.state()
val persistentTasks = clusterState.metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()

val followerShardIds = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
.map { shard -> shard.value.shardId }
.stream().collect(Collectors.toSet())
val runningShardTasksForIndex = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
.map { task -> task.params as ShardReplicationParams }
.filter {taskParam -> followerShardIds.contains(taskParam.followerShardId) }
.collect(Collectors.toList())

if (runningShardTasks.size == 0) {
if (runningShardTasksForIndex.size != followerShardIds.size) {
return InitFollowState
}

Expand Down Expand Up @@ -696,19 +685,27 @@ open class IndexReplicationTask(id: Long, type: String, action: String, descript
}
}

private suspend fun
startShardFollowTasks(tasks: Map<ShardId, PersistentTask<ShardReplicationParams>>): FollowingState {
suspend fun startNewOrMissingShardTasks(): Map<ShardId, PersistentTask<ShardReplicationParams>> {
assert(clusterService.state().routingTable.hasIndex(followerIndexName)) { "Can't find index $followerIndexName" }
val shards = clusterService.state().routingTable.indicesRouting().get(followerIndexName).shards()
val newTasks = shards.map {
val persistentTasks = clusterService.state().metadata.custom<PersistentTasksCustomMetadata>(PersistentTasksCustomMetadata.TYPE)
val runningShardTasks = persistentTasks.findTasks(ShardReplicationExecutor.TASK_NAME, Predicate { true }).stream()
.map { task -> task as PersistentTask<ShardReplicationParams> }
.filter { task -> task.params!!.followerShardId.indexName == followerIndexName}
.collect(Collectors.toMap(
{t: PersistentTask<ShardReplicationParams> -> t.params!!.followerShardId},
{t: PersistentTask<ShardReplicationParams> -> t}))

val tasks = shards.map {
it.value.shardId
}.associate { shardId ->
val task = tasks.getOrElse(shardId) {
val task = runningShardTasks.getOrElse(shardId) {
startReplicationTask(ShardReplicationParams(leaderAlias, ShardId(leaderIndex, shardId.id), shardId))
}
return@associate shardId to task
}
return FollowingState(newTasks)

return tasks
}

private suspend fun cancelRestore() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ import org.opensearch.tasks.TaskManager
import org.opensearch.test.ClusterServiceUtils
import org.opensearch.test.ClusterServiceUtils.setState
import org.opensearch.test.OpenSearchTestCase
import org.opensearch.test.OpenSearchTestCase.assertBusy
import org.opensearch.threadpool.TestThreadPool
import java.util.*
import java.util.concurrent.TimeUnit
Expand Down Expand Up @@ -150,6 +149,66 @@ class IndexReplicationTaskTests : OpenSearchTestCase() {

}

fun testStartNewShardTasks() = runBlocking {
val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask())
var taskManager = Mockito.mock(TaskManager::class.java)
replicationTask.setPersistent(taskManager)
var rc = ReplicationContext(followerIndex)
var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY)
replicationTask.setReplicationMetadata(rm)

// Build cluster state
val indices: MutableList<String> = ArrayList()
indices.add(followerIndex)
var metadata = Metadata.builder()
.put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
.put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0))
.build()
var routingTableBuilder = RoutingTable.builder()
.addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX))
.addAsNew(metadata.index(followerIndex))
var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build()
setState(clusterService, newClusterState)

// Try starting shard tasks
val shardTasks = replicationTask.startNewOrMissingShardTasks()
assertThat(shardTasks.size == 2).isTrue
}


fun testStartMissingShardTasks() = runBlocking {
val replicationTask: IndexReplicationTask = spy(createIndexReplicationTask())
var taskManager = Mockito.mock(TaskManager::class.java)
replicationTask.setPersistent(taskManager)
var rc = ReplicationContext(followerIndex)
var rm = ReplicationMetadata(connectionName, ReplicationStoreMetadataType.INDEX.name, ReplicationOverallState.RUNNING.name, "reason", rc, rc, Settings.EMPTY)
replicationTask.setReplicationMetadata(rm)

// Build cluster state
val indices: MutableList<String> = ArrayList()
indices.add(followerIndex)

val tasks = PersistentTasksCustomMetadata.builder()
var sId = ShardId(Index(followerIndex, "_na_"), 0)
tasks.addTask<PersistentTaskParams>( "replication:0", ShardReplicationExecutor.TASK_NAME, ShardReplicationParams("remoteCluster", sId, sId),
PersistentTasksCustomMetadata.Assignment("other_node_", "test assignment on other node"))

var metadata = Metadata.builder()
.put(IndexMetadata.builder(REPLICATION_CONFIG_SYSTEM_INDEX).settings(settings(Version.CURRENT)).numberOfShards(1).numberOfReplicas(0))
.put(IndexMetadata.builder(followerIndex).settings(settings(Version.CURRENT)).numberOfShards(2).numberOfReplicas(0))
.putCustom(PersistentTasksCustomMetadata.TYPE, tasks.build())
.build()
var routingTableBuilder = RoutingTable.builder()
.addAsNew(metadata.index(REPLICATION_CONFIG_SYSTEM_INDEX))
.addAsNew(metadata.index(followerIndex))
var newClusterState = ClusterState.builder(clusterService.state()).routingTable(routingTableBuilder.build()).build()
setState(clusterService, newClusterState)

// Try starting shard tasks
val shardTasks = replicationTask.startNewOrMissingShardTasks()
assertThat(shardTasks.size == 2).isTrue
}

private fun createIndexReplicationTask() : IndexReplicationTask {
var threadPool = TestThreadPool("IndexReplicationTask")
//Hack Alert : Though it is meant to force rejection , this is to make overallTaskScope not null
Expand Down

0 comments on commit 805f686

Please sign in to comment.