Skip to content

Commit

Permalink
Add new background task to fail stale replica shards.
Browse files Browse the repository at this point in the history
Signed-off-by: Rishikesh1159 <[email protected]>
  • Loading branch information
Rishikesh1159 committed Mar 28, 2023
1 parent 7500270 commit bf9b3dc
Show file tree
Hide file tree
Showing 4 changed files with 183 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardState;
import org.opensearch.indices.replication.SegmentReplicationBaseIT;
import org.opensearch.plugins.Plugin;
import org.opensearch.rest.RestStatus;
Expand All @@ -28,6 +29,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static java.util.Arrays.asList;
Expand Down Expand Up @@ -179,6 +181,46 @@ public void testBelowReplicaLimit() throws Exception {
verifyStoreContent();
}

public void testFailStaleReplica() throws Exception {

// Starts a primary and replica node.
final String primaryNode = internalCluster().startNode(
Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build()
);
createIndex(INDEX_NAME);
ensureYellowAndNoInitializingShards(INDEX_NAME);
final String replicaNode = internalCluster().startNode(
Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build()
);
ensureGreen(INDEX_NAME);

final IndexShard primaryShard = getIndexShard(primaryNode, INDEX_NAME);
final List<String> replicaNodes = asList(replicaNode);
assertEqualSegmentInfosVersion(replicaNodes, primaryShard);
IndexShard replicaShard = getIndexShard(replicaNode, INDEX_NAME);

final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger totalDocs = new AtomicInteger(0);
try (final Releasable ignored = blockReplication(replicaNodes, latch)) {
// Index docs until we replicas are staled.
Thread indexingThread = new Thread(() -> { totalDocs.getAndSet(indexUntilCheckpointCount()); });
indexingThread.start();
indexingThread.join();
latch.await();
// index again while we are stale.
indexDoc();
totalDocs.incrementAndGet();

// Verify that replica shard is closed.
assertBusy(() -> { assertTrue(replicaShard.state().equals(IndexShardState.CLOSED)); }, 1, TimeUnit.MINUTES);
}
ensureGreen(INDEX_NAME);
final IndexShard replicaAfterFailure = getIndexShard(replicaNode, INDEX_NAME);

// Verify that new replica shard after failure is different from old replica shard.
assertNotEquals(replicaAfterFailure.routingEntry().allocationId().getId(), replicaShard.routingEntry().allocationId().getId());
}

public void testBulkWritesRejected() throws Exception {
final String primaryNode = internalCluster().startNode();
createIndex(INDEX_NAME);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,33 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractAsyncTask;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.util.Set;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Service responsible for applying backpressure for lagging behind replicas when Segment Replication is enabled.
*
* @opensearch.internal
*/
public class SegmentReplicationPressureService {
public class SegmentReplicationPressureService implements Closeable {

private volatile boolean isSegmentReplicationBackpressureEnabled;
private volatile int maxCheckpointsBehind;
Expand Down Expand Up @@ -70,12 +77,30 @@ public class SegmentReplicationPressureService {
);

private final IndicesService indicesService;

private final ThreadPool threadPool;
private final SegmentReplicationStatsTracker tracker;

private final ShardStateAction shardStateAction;

public AsyncFailStaleReplicaTask getFailStaleReplicaTask() {
return failStaleReplicaTask;
}

private volatile AsyncFailStaleReplicaTask failStaleReplicaTask;

@Inject
public SegmentReplicationPressureService(Settings settings, ClusterService clusterService, IndicesService indicesService) {
public SegmentReplicationPressureService(
Settings settings,
ClusterService clusterService,
IndicesService indicesService,
ShardStateAction shardStateAction,
ThreadPool threadPool
) {
this.indicesService = indicesService;
this.tracker = new SegmentReplicationStatsTracker(this.indicesService);
this.shardStateAction = shardStateAction;
this.threadPool = threadPool;

final ClusterSettings clusterSettings = clusterService.getClusterSettings();
this.isSegmentReplicationBackpressureEnabled = SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.get(settings);
Expand All @@ -92,6 +117,8 @@ public SegmentReplicationPressureService(Settings settings, ClusterService clust

this.maxAllowedStaleReplicas = MAX_ALLOWED_STALE_SHARDS.get(settings);
clusterSettings.addSettingsUpdateConsumer(MAX_ALLOWED_STALE_SHARDS, this::setMaxAllowedStaleReplicas);

this.failStaleReplicaTask = new AsyncFailStaleReplicaTask(TimeValue.timeValueMillis(1));
}

public void isSegrepLimitBreached(ShardId shardId) {
Expand Down Expand Up @@ -154,4 +181,68 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) {
public void setMaxReplicationTime(TimeValue maxReplicationTime) {
this.maxReplicationTime = maxReplicationTime;
}

@Override
public void close() throws IOException {
failStaleReplicaTask.close();
}

// Background Task to fail replica shards if they are too far behind primary shard.
final class AsyncFailStaleReplicaTask extends AbstractAsyncTask {

AsyncFailStaleReplicaTask(TimeValue interval) {
super(logger, threadPool, interval, true);
rescheduleIfNecessary();
}

@Override
protected boolean mustReschedule() {
return true;
}

@Override
protected void runInternal() {
final SegmentReplicationStats stats = tracker.getStats();
for (Map.Entry<ShardId, SegmentReplicationPerGroupStats> entry : stats.getShardStats().entrySet()) {
final Set<SegmentReplicationShardStats> staleReplicas = getStaleReplicas(entry.getValue().getReplicaStats());
final ShardId shardId = entry.getKey();
final IndexService indexService = indicesService.indexService(shardId.getIndex());
final IndexShard primaryShard = indexService.getShard(shardId.getId());
for (SegmentReplicationShardStats staleReplica : staleReplicas) {
if (staleReplica.getCurrentReplicationTimeMillis() > 2 * maxReplicationTime.millis()) {
shardStateAction.remoteShardFailed(
shardId,
staleReplica.getAllocationId(),
primaryShard.getOperationPrimaryTerm(),
true,
"replica too far behind primary, marking as stale",
null,
new ActionListener<>() {
@Override
public void onResponse(Void unused) {
logger.trace(
"Successfully failed remote shardId [{}] allocation id [{}]",
shardId,
staleReplica.getAllocationId()
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to send remote shard failure", e);
}
}
);
}
}
}
}

@Override
public String toString() {
return "fail_stale_replica";
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@

package org.opensearch.index;

import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
Expand All @@ -21,6 +23,7 @@
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.threadpool.ThreadPool;

import java.util.Iterator;
import java.util.List;
Expand All @@ -29,13 +32,20 @@
import java.util.concurrent.TimeUnit;

import static java.util.Arrays.asList;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING;
import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;

public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevelReplicationTestCase {

private static ShardStateAction shardStateAction = Mockito.mock(ShardStateAction.class);
private static final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
Expand Down Expand Up @@ -181,6 +191,36 @@ public void testIsSegrepLimitBreached_underStaleNodeLimit() throws Exception {
}
}

public void testFailStaleReplicaTask() throws Exception {
final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
.put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(10))
.build();

try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) {
shards.startAll();
final IndexShard primaryShard = shards.getPrimary();
SegmentReplicationPressureService service = buildPressureService(settings, primaryShard);

// index docs in batches without refreshing
indexInBatches(5, shards, primaryShard);

// assert that replica shard is few checkpoints behind primary
Set<SegmentReplicationShardStats> replicationStats = primaryShard.getReplicationStats();
assertEquals(1, replicationStats.size());
SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get();
assertEquals(5, shardStats.getCheckpointsBehindCount());

// call the background task
service.getFailStaleReplicaTask().runInternal();

// verify that remote shard failed method is called which fails the replica shards falling behind.
verify(shardStateAction, times(1)).remoteShardFailed(any(), anyString(), anyLong(), anyBoolean(), anyString(), any(), any());
replicateSegments(primaryShard, shards.getReplicas());
}
}

private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception {
int totalDocs = 0;
for (int i = 0; i < count; i++) {
Expand All @@ -202,6 +242,6 @@ private SegmentReplicationPressureService buildPressureService(Settings settings
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));

return new SegmentReplicationPressureService(settings, clusterService, indicesService);
return new SegmentReplicationPressureService(settings, clusterService, indicesService, shardStateAction, mock(ThreadPool.class));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -1980,7 +1980,13 @@ public void onFailure(final Exception e) {
new UpdateHelper(scriptService),
actionFilters,
new IndexingPressureService(settings, clusterService),
new SegmentReplicationPressureService(settings, clusterService, mock(IndicesService.class)),
new SegmentReplicationPressureService(
settings,
clusterService,
mock(IndicesService.class),
mock(ShardStateAction.class),
mock(ThreadPool.class)
),
new SystemIndices(emptyMap())
);
actions.put(
Expand Down

0 comments on commit bf9b3dc

Please sign in to comment.