[Segment Replication] Add new background task to fail stale replica shards. #6850

Merged
Changes from 6 commits
@@ -16,6 +16,7 @@
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardState;
import org.opensearch.indices.replication.SegmentReplicationBaseIT;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.plugins.Plugin;
@@ -29,6 +30,7 @@
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

import static java.util.Arrays.asList;
@@ -200,6 +202,44 @@ public void testBelowReplicaLimit() throws Exception {
verifyStoreContent();
}

public void testFailStaleReplica() throws Exception {

Settings settings = Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build();
// Starts a primary and replica node.
final String primaryNode = internalCluster().startNode(settings);
createIndex(INDEX_NAME);
ensureYellowAndNoInitializingShards(INDEX_NAME);
final String replicaNode = internalCluster().startNode(settings);
ensureGreen(INDEX_NAME);

final IndexShard primaryShard = getIndexShard(primaryNode, INDEX_NAME);
final List<String> replicaNodes = asList(replicaNode);
assertEqualSegmentInfosVersion(replicaNodes, primaryShard);
IndexShard replicaShard = getIndexShard(replicaNode, INDEX_NAME);

final CountDownLatch latch = new CountDownLatch(1);
final AtomicInteger totalDocs = new AtomicInteger(0);
try (final Releasable ignored = blockReplication(replicaNodes, latch)) {
// Index docs until replicas are staled.
Thread indexingThread = new Thread(() -> { totalDocs.getAndSet(indexUntilCheckpointCount()); });
indexingThread.start();
indexingThread.join();
Member: Why use another thread if you immediately block on it to complete? Can you just replace this with totalDocs.getAndSet(indexUntilCheckpointCount());?

latch.await();
// index again while we are stale.
indexDoc();
refresh(INDEX_NAME);
totalDocs.incrementAndGet();

// Verify that replica shard is closed.
assertBusy(() -> { assertTrue(replicaShard.state().equals(IndexShardState.CLOSED)); }, 1, TimeUnit.MINUTES);
}
ensureGreen(INDEX_NAME);
final IndexShard replicaAfterFailure = getIndexShard(replicaNode, INDEX_NAME);

// Verify that new replica shard after failure is different from old replica shard.
assertNotEquals(replicaAfterFailure.routingEntry().allocationId().getId(), replicaShard.routingEntry().allocationId().getId());
}

public void testBulkWritesRejected() throws Exception {
final String primaryNode = internalCluster().startNode();
createIndex(INDEX_NAME);
@@ -10,26 +10,33 @@

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionListener;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractAsyncTask;
import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.threadpool.ThreadPool;

import java.io.Closeable;
import java.io.IOException;
import java.util.Set;
import java.util.Map;
import java.util.stream.Collectors;

/**
* Service responsible for applying backpressure for lagging behind replicas when Segment Replication is enabled.
*
* @opensearch.internal
*/
public class SegmentReplicationPressureService {
public class SegmentReplicationPressureService implements Closeable {

private volatile boolean isSegmentReplicationBackpressureEnabled;
private volatile int maxCheckpointsBehind;
@@ -70,12 +77,30 @@ public class SegmentReplicationPressureService {
);

private final IndicesService indicesService;

private final ThreadPool threadPool;
private final SegmentReplicationStatsTracker tracker;

private final ShardStateAction shardStateAction;

public AsyncFailStaleReplicaTask getFailStaleReplicaTask() {
Member: Please move this below the constructor next to the other methods in this class. Also, can this be package private? Maybe comment it as "visible for testing" if it only exists for testing purposes.

return failStaleReplicaTask;
}
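A minimal sketch of what that suggestion could look like (not part of this diff): the getter made package-private with a test-only note.

    // Sketch only -- reviewer's suggestion, not part of this PR.
    // Visible for testing
    AsyncFailStaleReplicaTask getFailStaleReplicaTask() {
        return failStaleReplicaTask;
    }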

private volatile AsyncFailStaleReplicaTask failStaleReplicaTask;
Member: Why volatile? It looks like it is only set once by the constructor, so can it be final?


@Inject
public SegmentReplicationPressureService(Settings settings, ClusterService clusterService, IndicesService indicesService) {
public SegmentReplicationPressureService(
Settings settings,
ClusterService clusterService,
IndicesService indicesService,
ShardStateAction shardStateAction,
ThreadPool threadPool
) {
this.indicesService = indicesService;
this.tracker = new SegmentReplicationStatsTracker(this.indicesService);
this.shardStateAction = shardStateAction;
this.threadPool = threadPool;

final ClusterSettings clusterSettings = clusterService.getClusterSettings();
this.isSegmentReplicationBackpressureEnabled = SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.get(settings);
@@ -92,6 +117,8 @@ public SegmentReplicationPressureService(Settings settings, ClusterService clust

this.maxAllowedStaleReplicas = MAX_ALLOWED_STALE_SHARDS.get(settings);
clusterSettings.addSettingsUpdateConsumer(MAX_ALLOWED_STALE_SHARDS, this::setMaxAllowedStaleReplicas);

this.failStaleReplicaTask = new AsyncFailStaleReplicaTask(TimeValue.timeValueSeconds(30));
}

public void isSegrepLimitBreached(ShardId shardId) {
@@ -154,4 +181,70 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) {
public void setMaxReplicationTime(TimeValue maxReplicationTime) {
this.maxReplicationTime = maxReplicationTime;
}

@Override
public void close() throws IOException {
Member: Where does this newly added method get called?

Member Author (@Rishikesh1159, Apr 4, 2023): Currently it is not called/used directly from anywhere. I took the reference from PersistentTasksClusterService, just to make sure that when the service is closed this async task is also closed.

failStaleReplicaTask.close();
}

// Background Task to fail replica shards if they are too far behind primary shard.
final class AsyncFailStaleReplicaTask extends AbstractAsyncTask {
Member: The base class defaults to ThreadPool.SAME, so which thread pool will this run on? It looks like it initially gets scheduled during initialization, so I honestly don't know what thread pool that is. I think it is worth being explicit by overriding getThreadPool().

Another point, and this is a bit of a style concern, but I really dislike non-static inner classes in Java because they implicitly bind the instance to the this pointer of the outer class instance that creates it. I would recommend making this static. If you need access to lots of members of the outer class, you can explicitly pass in a reference like is done here.

Member Author (@Rishikesh1159, Mar 30, 2023): @andrross thanks for the review. Yes, I think it uses the default ThreadPool.SAME as you said. I think using the generic pool ThreadPool.Names.GENERIC here might make sense, as we do for a few other async tasks; please let me know if you have other thoughts.

And for the style concern, I have updated the inner class to static as you suggested.
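A hedged sketch of what those two suggestions might look like together. This is not part of the diff shown here; it assumes AbstractAsyncTask exposes a protected String getThreadPool() hook (as other OpenSearch async tasks override) and that the service's logger field is static.

    // Sketch: static nested task with an explicit back-reference and an explicit thread pool.
    static final class AsyncFailStaleReplicaTask extends AbstractAsyncTask {

        private final SegmentReplicationPressureService pressureService;

        AsyncFailStaleReplicaTask(SegmentReplicationPressureService pressureService, TimeValue interval) {
            // Assumes the outer class's logger is a static field.
            super(logger, pressureService.threadPool, interval, true);
            this.pressureService = pressureService;
            rescheduleIfNecessary();
        }

        @Override
        protected String getThreadPool() {
            // Run on the generic pool instead of the default SAME (caller) pool.
            return ThreadPool.Names.GENERIC;
        }

        @Override
        protected boolean mustReschedule() {
            return true;
        }

        @Override
        protected void runInternal() {
            // Reach outer-class state through the explicit reference rather than an implicit this$0.
            if (pressureService.isSegmentReplicationBackpressureEnabled) {
                // ... same stale-replica handling as in the diff below ...
            }
        }
    }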


AsyncFailStaleReplicaTask(TimeValue interval) {
super(logger, threadPool, interval, true);
rescheduleIfNecessary();
}

@Override
protected boolean mustReschedule() {
return true;
}

@Override
protected void runInternal() {
if (isSegmentReplicationBackpressureEnabled) {
final SegmentReplicationStats stats = tracker.getStats();
for (Map.Entry<ShardId, SegmentReplicationPerGroupStats> entry : stats.getShardStats().entrySet()) {
Member: Have you considered not being so aggressive with failing every lagging shard on a single execution? Say you've got one really hot shard that is causing the host to brown out, so all shards start to lag behind. Once they cross the threshold you'll fail them all and the host will then be more or less idle. An alternative would be to fail at most one shard per iteration (or some other fixed number), so it would shed load more gradually.

Member Author: Thanks @andrross for your input. I think what you said makes sense. I have updated the background task logic to fail the stale replicas of only one shardId in a single iteration of the background task. This way we can gradually reduce the load on the node.
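A rough sketch of how a one-shard-group-per-iteration runInternal could look (not the code in this diff; the follow-up commits may differ). It reuses the fields and getStaleReplicas helper from this class, needs a java.util.Comparator import, and the "most stale replicas first" ordering is an assumption.

    // Sketch: fail stale replicas of at most one shard group per run, to shed load gradually.
    @Override
    protected void runInternal() {
        if (isSegmentReplicationBackpressureEnabled) {
            final SegmentReplicationStats stats = tracker.getStats();
            stats.getShardStats()
                .entrySet()
                .stream()
                // Assumption: act on the shard group with the most stale replicas this iteration.
                .max(
                    Comparator.comparingInt(
                        (Map.Entry<ShardId, SegmentReplicationPerGroupStats> e) -> getStaleReplicas(e.getValue().getReplicaStats()).size()
                    )
                )
                .ifPresent(entry -> {
                    final ShardId shardId = entry.getKey();
                    for (SegmentReplicationShardStats staleReplica : getStaleReplicas(entry.getValue().getReplicaStats())) {
                        // ... same shardStateAction.remoteShardFailed(shardId, ...) call as in the loop below ...
                    }
                });
        }
    }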

final Set<SegmentReplicationShardStats> staleReplicas = getStaleReplicas(entry.getValue().getReplicaStats());
final ShardId shardId = entry.getKey();
final IndexService indexService = indicesService.indexService(shardId.getIndex());
final IndexShard primaryShard = indexService.getShard(shardId.getId());
for (SegmentReplicationShardStats staleReplica : staleReplicas) {
if (staleReplica.getCurrentReplicationTimeMillis() > 2 * maxReplicationTime.millis()) {
Contributor: What would the default value be for max replication time? Is there any documentation we're providing on how to set this value when people enable backpressure?

Member Author: The default value for max replication time is 5 minutes. For documentation, we need to open an issue with the documentation repo documenting all changes/usages of segrep backpressure. I will open that issue after this PR gets merged.

shardStateAction.remoteShardFailed(
shardId,
staleReplica.getAllocationId(),
primaryShard.getOperationPrimaryTerm(),
true,
"replica too far behind primary, marking as stale",
null,
new ActionListener<>() {
@Override
public void onResponse(Void unused) {
logger.trace(
"Successfully failed remote shardId [{}] allocation id [{}]",
shardId,
staleReplica.getAllocationId()
);
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to send remote shard failure", e);
}
}
);
}
}
}
}
}

@Override
public String toString() {
return "fail_stale_replica";
}

}

}
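Following up on the default-value discussion in the review thread above, a hedged sketch of raising the replication-time limit the same way the integration test at the top of this diff does. The Java setting constant is used because the raw key string does not appear in this diff.

    // Sketch: raise the per-replica replication-time limit from its default (5 minutes per the
    // discussion above) before relying on backpressure-driven shard failure.
    Settings settings = Settings.builder()
        .put(SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMinutes(10))
        .build();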
@@ -8,7 +8,9 @@

package org.opensearch.index;

import org.mockito.Mockito;
import org.mockito.stubbing.Answer;
import org.opensearch.cluster.action.shard.ShardStateAction;
import org.opensearch.cluster.metadata.IndexMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.ClusterSettings;
@@ -21,6 +23,7 @@
import org.opensearch.index.shard.ShardId;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.threadpool.ThreadPool;

import java.util.Iterator;
import java.util.List;
@@ -29,13 +32,20 @@
import java.util.concurrent.TimeUnit;

import static java.util.Arrays.asList;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;
import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING;
import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED;

public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevelReplicationTestCase {

private static ShardStateAction shardStateAction = Mockito.mock(ShardStateAction.class);
private static final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
@@ -181,6 +191,36 @@ public void testIsSegrepLimitBreached_underStaleNodeLimit() throws Exception {
}
}

public void testFailStaleReplicaTask() throws Exception {
final Settings settings = Settings.builder()
.put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT)
.put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true)
.put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(10))
.build();

try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) {
shards.startAll();
final IndexShard primaryShard = shards.getPrimary();
SegmentReplicationPressureService service = buildPressureService(settings, primaryShard);

// index docs in batches without refreshing
indexInBatches(5, shards, primaryShard);

// assert that replica shard is few checkpoints behind primary
Set<SegmentReplicationShardStats> replicationStats = primaryShard.getReplicationStats();
assertEquals(1, replicationStats.size());
SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get();
assertEquals(5, shardStats.getCheckpointsBehindCount());

// call the background task
service.getFailStaleReplicaTask().runInternal();

// verify that remote shard failed method is called which fails the replica shards falling behind.
verify(shardStateAction, times(1)).remoteShardFailed(any(), anyString(), anyLong(), anyBoolean(), anyString(), any(), any());
Member: nit: Instead of random values, we can use the actual values, i.e. shardId, allocationId, primaryOperationTerm, true, "replica too far behind primary, marking as stale", null.

Member Author (@Rishikesh1159, Mar 30, 2023): I tried doing this, but to verify exact values we would need a reference to the listener (the last parameter). We cannot pass one, as verify would fail, and I also cannot use any() for just the last parameter because we cannot combine any() with other raw parameters.

Member: Is there a race condition here? Is it possible for the explicit runInternal() call to happen as well as the one that gets scheduled by default, which would then cause the times(1) assertion to fail?

Member Author: On the nit about using actual values: if we try to combine any() with regular parameters in the method call, we get an exception:

    This exception may occur if matchers are combined with raw values:
        //incorrect:
        someMethod(any(), "raw String");
    When using matchers, all arguments have to be provided by matchers.
    For example:
        //correct:
        someMethod(any(), eq("String by matcher"));

    For more info see javadoc for Matchers class.

Member Author (@Rishikesh1159, Mar 30, 2023): On the race-condition question: no, there is no race condition in this unit test. There are two ways to start the background task: call task.runInternal() directly, or call rescheduleIfNecessary(). In an actual cluster or in an integration test, the task is started when a new instance of the class is created. But here, because this is a unit test and we are using mocks, rescheduleIfNecessary() would not trigger runInternal() since the class is just a mock. So, to verify that the task actually works in the unit test, I make an explicit service.getFailStaleReplicaTask().runInternal() call; this way the task runs only once.
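For reference, a hedged sketch of the all-matchers verify described by the quoted Mockito message: fixed arguments wrapped in eq(...) so they can be combined with any() for the listener. It uses the shardStats and primaryShard variables already present in this test and would need static imports of eq and isNull from org.mockito.ArgumentMatchers; whether it fits this test is a separate question, per the discussion above.

    // Sketch only: every argument goes through a matcher, so known values can be mixed with any().
    verify(shardStateAction, times(1)).remoteShardFailed(
        eq(primaryShard.shardId()),
        eq(shardStats.getAllocationId()),
        eq(primaryShard.getOperationPrimaryTerm()),
        eq(true),
        eq("replica too far behind primary, marking as stale"),
        isNull(),
        any()
    );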

replicateSegments(primaryShard, shards.getReplicas());
}
}

private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception {
int totalDocs = 0;
for (int i = 0; i < count; i++) {
@@ -202,6 +242,6 @@ private SegmentReplicationPressureService buildPressureService(Settings settings
ClusterService clusterService = mock(ClusterService.class);
when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS));

return new SegmentReplicationPressureService(settings, clusterService, indicesService);
return new SegmentReplicationPressureService(settings, clusterService, indicesService, shardStateAction, mock(ThreadPool.class));
}
}
@@ -1980,7 +1980,13 @@ public void onFailure(final Exception e) {
new UpdateHelper(scriptService),
actionFilters,
new IndexingPressureService(settings, clusterService),
new SegmentReplicationPressureService(settings, clusterService, mock(IndicesService.class)),
new SegmentReplicationPressureService(
settings,
clusterService,
mock(IndicesService.class),
mock(ShardStateAction.class),
mock(ThreadPool.class)
),
new SystemIndices(emptyMap())
);
actions.put(