[Segment Replication] Add new background task to fail stale replica shards. #6850

Merged

Changes from 13 commits
@@ -16,6 +16,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardState;
import org.opensearch.indices.replication.SegmentReplicationBaseIT;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.plugins.Plugin;
@@ -29,6 +30,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static java.util.Arrays.asList;
@@ -200,6 +202,42 @@ public void testBelowReplicaLimit() throws Exception {
verifyStoreContent();
}

public void testFailStaleReplica() throws Exception {

Settings settings = Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build();
// Starts a primary and replica node.
final String primaryNode = internalCluster().startNode(settings);
createIndex(INDEX_NAME);
ensureYellowAndNoInitializingShards(INDEX_NAME);
final String replicaNode = internalCluster().startNode(settings);
ensureGreen(INDEX_NAME);

final IndexShard primaryShard = getIndexShard(primaryNode, INDEX_NAME);
final List<String> replicaNodes = asList(replicaNode);
assertEqualSegmentInfosVersion(replicaNodes, primaryShard);
IndexShard replicaShard = getIndexShard(replicaNode, INDEX_NAME);

final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger totalDocs = new AtomicInteger(0);
try (final Releasable ignored = blockReplication(replicaNodes, latch)) {
// Index docs until replicas are stale.
totalDocs.getAndSet(indexUntilCheckpointCount());
latch.await();
// index again while we are stale.
indexDoc();
refresh(INDEX_NAME);
totalDocs.incrementAndGet();

// Verify that replica shard is closed.
assertBusy(() -> { assertTrue(replicaShard.state().equals(IndexShardState.CLOSED)); }, 1, TimeUnit.MINUTES);
}
ensureGreen(INDEX_NAME);
final IndexShard replicaAfterFailure = getIndexShard(replicaNode, INDEX_NAME);

// Verify that new replica shard after failure is different from old replica shard.
assertNotEquals(replicaAfterFailure.routingEntry().allocationId().getId(), replicaShard.routingEntry().allocationId().getId());
}

public void testBulkWritesRejected() throws Exception {
final String primaryNode = internalCluster().startNode();
createIndex(INDEX_NAME);
Expand Down
@@ -10,17 +10,25 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.Tuple;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractAsyncTask;
import org.opensearch.core.concurrency.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.util.Comparator;
import java.util.Set;
import java.util.stream.Collectors;

@@ -29,7 +37,7 @@
*
* @opensearch.internal
*/
public class SegmentReplicationPressureService {
public class SegmentReplicationPressureService implements Closeable {

private volatile boolean isSegmentReplicationBackpressureEnabled;
private volatile int maxCheckpointsBehind;
@@ -38,6 +46,10 @@ public class SegmentReplicationPressureService {

private static final Logger logger = LogManager.getLogger(SegmentReplicationPressureService.class);

/**
* When enabled, writes will be rejected when a replica shard falls behind by both the MAX_REPLICATION_TIME_SETTING time value and the MAX_INDEXING_CHECKPOINTS number of checkpoints.
* Once a shard falls behind double the MAX_REPLICATION_TIME_SETTING time value, it will be failed.
*/
public static final Setting<Boolean> SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED = Setting.boolSetting(
"segrep.pressure.enabled",
false,
Expand Down Expand Up @@ -70,12 +82,25 @@ public class SegmentReplicationPressureService {
);
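
As a usage illustration (not part of this diff), a minimal sketch of enabling the feature through the setting constants in this class — the five-minute value is arbitrary:

    // Hedged sketch: enabling segment replication backpressure. With these values,
    // writes are rejected once a replica is both MAX_INDEXING_CHECKPOINTS checkpoints
    // and 5 minutes behind, and the replica is failed once it is 10 minutes behind.
    Settings settings = Settings.builder()
        .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true) // "segrep.pressure.enabled"
        .put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMinutes(5))
        .build();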

private final IndicesService indicesService;

private final ThreadPool threadPool;
private final SegmentReplicationStatsTracker tracker;
private final ShardStateAction shardStateAction;

private final AsyncFailStaleReplicaTask failStaleReplicaTask;

@Inject
public SegmentReplicationPressureService(Settings settings, ClusterService clusterService, IndicesService indicesService) {
public SegmentReplicationPressureService(
Settings settings,
ClusterService clusterService,
IndicesService indicesService,
ShardStateAction shardStateAction,
ThreadPool threadPool
) {
this.indicesService = indicesService;
this.tracker = new SegmentReplicationStatsTracker(this.indicesService);
this.shardStateAction = shardStateAction;
this.threadPool = threadPool;

final ClusterSettings clusterSettings = clusterService.getClusterSettings();
this.isSegmentReplicationBackpressureEnabled = SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.get(settings);
@@ -92,6 +117,13 @@ public SegmentReplicationPressureService(Settings settings, ClusterService clust

this.maxAllowedStaleReplicas = MAX_ALLOWED_STALE_SHARDS.get(settings);
clusterSettings.addSettingsUpdateConsumer(MAX_ALLOWED_STALE_SHARDS, this::setMaxAllowedStaleReplicas);

this.failStaleReplicaTask = new AsyncFailStaleReplicaTask(this);
}

// visible for testing
AsyncFailStaleReplicaTask getFailStaleReplicaTask() {
return failStaleReplicaTask;
}

public void isSegrepLimitBreached(ShardId shardId) {
@@ -154,4 +186,95 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) {
public void setMaxReplicationTime(TimeValue maxReplicationTime) {
this.maxReplicationTime = maxReplicationTime;
}

@Override
public void close() throws IOException {

[Review thread]

Member: Where does this newly added method get called?

@Rishikesh1159 (Member Author, Apr 4, 2023): Currently, it is not called or used directly from anywhere. I took the reference from PersistentTasksClusterService, just to make sure that when the service is closed this async task is also closed.

failStaleReplicaTask.close();
}
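
A hedged illustration of the intended lifecycle (hypothetical caller — this PR itself does not wire the call up anywhere): whatever component owns the service can release the background task with an explicit close.

    // Hypothetical shutdown path, not part of this PR: closing the service
    // cancels the scheduled AsyncFailStaleReplicaTask.
    try {
        pressureService.close();
    } catch (IOException e) {
        logger.warn("failed to close SegmentReplicationPressureService", e);
    }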

// Background task to fail replica shards if they are too far behind the primary shard.
final static class AsyncFailStaleReplicaTask extends AbstractAsyncTask {

final SegmentReplicationPressureService pressureService;

static final TimeValue INTERVAL = TimeValue.timeValueSeconds(30);

AsyncFailStaleReplicaTask(SegmentReplicationPressureService pressureService) {
super(logger, pressureService.threadPool, INTERVAL, true);
this.pressureService = pressureService;
rescheduleIfNecessary();
}

@Override
protected boolean mustReschedule() {
return true;
}

@Override
protected void runInternal() {
if (pressureService.isSegmentReplicationBackpressureEnabled) {
final SegmentReplicationStats stats = pressureService.tracker.getStats();

// Find the shardId on this node that has stale replicas with the highest current replication time.
// This way we only fail one shardId's stale replicas in each iteration of this background async task,
// thereby decreasing load on the node gradually.
stats.getShardStats()
.entrySet()
.stream()
.flatMap(
entry -> pressureService.getStaleReplicas(entry.getValue().getReplicaStats())
.stream()
.map(r -> Tuple.tuple(entry.getKey(), r.getCurrentReplicationTimeMillis()))
)
.max(Comparator.comparingLong(Tuple::v2))
.map(Tuple::v1)
.ifPresent(shardId -> {
final Set<SegmentReplicationShardStats> staleReplicas = pressureService.getStaleReplicas(
stats.getShardStats().get(shardId).getReplicaStats()
);
final IndexService indexService = pressureService.indicesService.indexService(shardId.getIndex());
final IndexShard primaryShard = indexService.getShard(shardId.getId());
for (SegmentReplicationShardStats staleReplica : staleReplicas) {
if (staleReplica.getCurrentReplicationTimeMillis() > 2 * pressureService.maxReplicationTime.millis()) {
pressureService.shardStateAction.remoteShardFailed(
shardId,
staleReplica.getAllocationId(),
primaryShard.getOperationPrimaryTerm(),
true,
"replica too far behind primary, marking as stale",
null,
new ActionListener<>() {
@Override
public void onResponse(Void unused) {
logger.trace(
"Successfully failed remote shardId [{}] allocation id [{}]",
shardId,
staleReplica.getAllocationId()
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to send remote shard failure", e);
}
}
);
}
}
});
}
}

@Override
protected String getThreadPool() {
return ThreadPool.Names.GENERIC;
}

@Override
public String toString() {
return "fail_stale_replica";
}

}

}
@@ -8,7 +8,9 @@

package org.opensearch.index;

import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
@@ -21,6 +23,7 @@
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.threadpool.ThreadPool;

import java.util.Iterator;
import java.util.List;
@@ -29,13 +32,20 @@
import java.util.concurrent.TimeUnit;

import static java.util.Arrays.asList;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING;
import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;

public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevelReplicationTestCase {

private static ShardStateAction shardStateAction = Mockito.mock(ShardStateAction.class);
private static final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
@@ -181,6 +191,36 @@ public void testIsSegrepLimitBreached_underStaleNodeLimit() throws Exception {
}
}

public void testFailStaleReplicaTask() throws Exception {
final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
.put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(10))
.build();

try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) {
shards.startAll();
final IndexShard primaryShard = shards.getPrimary();
SegmentReplicationPressureService service = buildPressureService(settings, primaryShard);

// index docs in batches without refreshing
indexInBatches(5, shards, primaryShard);

// assert that the replica shard is a few checkpoints behind the primary
Set<SegmentReplicationShardStats> replicationStats = primaryShard.getReplicationStats();
assertEquals(1, replicationStats.size());
SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get();
assertEquals(5, shardStats.getCheckpointsBehindCount());

// call the background task
service.getFailStaleReplicaTask().runInternal();

// verify that the remote shard failed method is called, which fails the replica shards falling behind.
verify(shardStateAction, times(1)).remoteShardFailed(any(), anyString(), anyLong(), anyBoolean(), anyString(), any(), any());

[Review thread]

Member: nit: Instead of random values, we can use actual values, i.e. shardId, allocationId, primaryOperationTerm, true, "replica too far behind primary, marking as stale", null.

@Rishikesh1159 (Member Author, Mar 30, 2023): I tried doing this, but to verify exact values we need a reference to the listener (the last parameter). We cannot pass it, as verify would fail, and I also cannot use any() for just the last parameter, because any() cannot be combined with raw values.

Member: Is there a race condition here? Is it possible for the explicit runInternal() call to happen as well as the one that gets scheduled by default, which would then cause the times(1) assertion to fail?

@Rishikesh1159 (Member Author): On the nit above: if we try to combine any() with other regular parameter values in the method call, then we get an exception:

    This exception may occur if matchers are combined with raw values:
        //incorrect:
        someMethod(any(), "raw String");
    When using matchers, all arguments have to be provided by matchers.
    For example:
        //correct:
        someMethod(any(), eq("String by matcher"));

    For more info see javadoc for Matchers class.

@Rishikesh1159 (Member Author, Mar 30, 2023): On the race condition: no, there is no race condition in this unit test. There are two ways to start the background task: task.runInternal() or rescheduleIfNecessary(). In an actual cluster, or in an integration test, the background task starts when a new instance of the class is created. But here, since this is a unit test and we are using mocks, rescheduleIfNecessary() would not trigger runInternal(), because the thread pool is just a mock. So, to verify that the task actually works in a unit test, I make an explicit service.getFailStaleReplicaTask().runInternal() call. This way the task runs only once.
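
For illustration, a hedged sketch of what the fully-matched verification discussed in the nit could look like — shardId, allocationId, and primaryTerm are hypothetical locals captured from the test fixture, and the listener argument has to stay any(), since the listener instance is constructed inside runInternal():

    // Sketch only: every argument goes through a matcher, per Mockito's rule.
    verify(shardStateAction, times(1)).remoteShardFailed(
        eq(shardId),                     // ShardId of the lagging shard
        eq(allocationId),                // allocation id of the stale replica
        eq(primaryTerm),                 // primary term read from the primary shard
        eq(true),                        // markAsStale
        eq("replica too far behind primary, marking as stale"),
        isNull(),                        // no failure cause
        any()                            // listener created inside the task
    );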

replicateSegments(primaryShard, shards.getReplicas());
}
}
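
As a side note on the race-condition answer above: the service in this test is built with mock(ThreadPool.class), so the scheduling triggered by AsyncFailStaleReplicaTask's constructor is inert. A hedged sketch of why, relying on Mockito's default stubbing:

    // A bare Mockito mock returns default values from every method, so
    // AbstractAsyncTask's internal schedule call queues nothing and the task
    // body never fires on its own; only the explicit runInternal() runs.
    ThreadPool threadPool = mock(ThreadPool.class);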

private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception {
int totalDocs = 0;
for (int i = 0; i < count; i++) {
Expand All @@ -202,6 +242,6 @@ private SegmentReplicationPressureService buildPressureService(Settings settings
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));

return new SegmentReplicationPressureService(settings, clusterService, indicesService);
return new SegmentReplicationPressureService(settings, clusterService, indicesService, shardStateAction, mock(ThreadPool.class));
}
}
@@ -1980,7 +1980,13 @@ public void onFailure(final Exception e) {
new UpdateHelper(scriptService),
actionFilters,
new IndexingPressureService(settings, clusterService),
new SegmentReplicationPressureService(settings, clusterService, mock(IndicesService.class)),
new SegmentReplicationPressureService(
settings,
clusterService,
mock(IndicesService.class),
mock(ShardStateAction.class),
mock(ThreadPool.class)
),
new SystemIndices(emptyMap())
);
actions.put(