diff --git a/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java b/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java index ad6c396df69a1..35d6a9ef0ef1d 100644 --- a/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/index/SegmentReplicationPressureIT.java @@ -16,6 +16,7 @@ import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; import org.opensearch.index.shard.IndexShard; +import org.opensearch.index.shard.IndexShardState; import org.opensearch.indices.replication.SegmentReplicationBaseIT; import org.opensearch.indices.replication.common.ReplicationType; import org.opensearch.plugins.Plugin; @@ -29,6 +30,7 @@ import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import static java.util.Arrays.asList; @@ -200,6 +202,42 @@ public void testBelowReplicaLimit() throws Exception { verifyStoreContent(); } + public void testFailStaleReplica() throws Exception { + + Settings settings = Settings.builder().put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(500)).build(); + // Starts a primary and replica node. + final String primaryNode = internalCluster().startNode(settings); + createIndex(INDEX_NAME); + ensureYellowAndNoInitializingShards(INDEX_NAME); + final String replicaNode = internalCluster().startNode(settings); + ensureGreen(INDEX_NAME); + + final IndexShard primaryShard = getIndexShard(primaryNode, INDEX_NAME); + final List replicaNodes = asList(replicaNode); + assertEqualSegmentInfosVersion(replicaNodes, primaryShard); + IndexShard replicaShard = getIndexShard(replicaNode, INDEX_NAME); + + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger totalDocs = new AtomicInteger(0); + try (final Releasable ignored = blockReplication(replicaNodes, latch)) { + // Index docs until replicas are staled. + totalDocs.getAndSet(indexUntilCheckpointCount()); + latch.await(); + // index again while we are stale. + indexDoc(); + refresh(INDEX_NAME); + totalDocs.incrementAndGet(); + + // Verify that replica shard is closed. + assertBusy(() -> { assertTrue(replicaShard.state().equals(IndexShardState.CLOSED)); }, 1, TimeUnit.MINUTES); + } + ensureGreen(INDEX_NAME); + final IndexShard replicaAfterFailure = getIndexShard(replicaNode, INDEX_NAME); + + // Verify that new replica shard after failure is different from old replica shard. + assertNotEquals(replicaAfterFailure.routingEntry().allocationId().getId(), replicaShard.routingEntry().allocationId().getId()); + } + public void testBulkWritesRejected() throws Exception { final String primaryNode = internalCluster().startNode(); createIndex(INDEX_NAME); diff --git a/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java b/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java index f31e236fb6184..7117836ce7873 100644 --- a/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java +++ b/server/src/main/java/org/opensearch/index/SegmentReplicationPressureService.java @@ -10,17 +10,25 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.cluster.action.shard.ShardStateAction; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.collect.Tuple; import org.opensearch.common.inject.Inject; import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; +import org.opensearch.common.util.concurrent.AbstractAsyncTask; import org.opensearch.common.util.concurrent.OpenSearchRejectedExecutionException; import org.opensearch.index.shard.IndexShard; import org.opensearch.index.shard.ShardId; import org.opensearch.indices.IndicesService; +import org.opensearch.threadpool.ThreadPool; +import java.io.Closeable; +import java.io.IOException; +import java.util.Comparator; import java.util.Set; import java.util.stream.Collectors; @@ -29,7 +37,7 @@ * * @opensearch.internal */ -public class SegmentReplicationPressureService { +public class SegmentReplicationPressureService implements Closeable { private volatile boolean isSegmentReplicationBackpressureEnabled; private volatile int maxCheckpointsBehind; @@ -38,6 +46,10 @@ public class SegmentReplicationPressureService { private static final Logger logger = LogManager.getLogger(SegmentReplicationPressureService.class); + /** + * When enabled, writes will be rejected when a replica shard falls behind by both the MAX_REPLICATION_TIME_SETTING time value and MAX_INDEXING_CHECKPOINTS number of checkpoints. + * Once a shard falls behind double the MAX_REPLICATION_TIME_SETTING time value it will be marked as failed. + */ public static final Setting SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED = Setting.boolSetting( "segrep.pressure.enabled", false, @@ -70,13 +82,28 @@ public class SegmentReplicationPressureService { ); private final IndicesService indicesService; + + private final ThreadPool threadPool; private final SegmentReplicationStatsTracker tracker; + private final ShardStateAction shardStateAction; + + private final AsyncFailStaleReplicaTask failStaleReplicaTask; + @Inject - public SegmentReplicationPressureService(Settings settings, ClusterService clusterService, IndicesService indicesService) { + public SegmentReplicationPressureService( + Settings settings, + ClusterService clusterService, + IndicesService indicesService, + ShardStateAction shardStateAction, + ThreadPool threadPool + ) { this.indicesService = indicesService; this.tracker = new SegmentReplicationStatsTracker(this.indicesService); + this.shardStateAction = shardStateAction; + this.threadPool = threadPool; + final ClusterSettings clusterSettings = clusterService.getClusterSettings(); this.isSegmentReplicationBackpressureEnabled = SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.get(settings); clusterSettings.addSettingsUpdateConsumer( @@ -92,6 +119,13 @@ public SegmentReplicationPressureService(Settings settings, ClusterService clust this.maxAllowedStaleReplicas = MAX_ALLOWED_STALE_SHARDS.get(settings); clusterSettings.addSettingsUpdateConsumer(MAX_ALLOWED_STALE_SHARDS, this::setMaxAllowedStaleReplicas); + + this.failStaleReplicaTask = new AsyncFailStaleReplicaTask(this); + } + + // visible for testing + AsyncFailStaleReplicaTask getFailStaleReplicaTask() { + return failStaleReplicaTask; } public void isSegrepLimitBreached(ShardId shardId) { @@ -154,4 +188,94 @@ public void setMaxAllowedStaleReplicas(double maxAllowedStaleReplicas) { public void setMaxReplicationTime(TimeValue maxReplicationTime) { this.maxReplicationTime = maxReplicationTime; } + + @Override + public void close() throws IOException { + failStaleReplicaTask.close(); + } + + // Background Task to fail replica shards if they are too far behind primary shard. + final static class AsyncFailStaleReplicaTask extends AbstractAsyncTask { + + final SegmentReplicationPressureService pressureService; + + static final TimeValue INTERVAL = TimeValue.timeValueSeconds(30); + + AsyncFailStaleReplicaTask(SegmentReplicationPressureService pressureService) { + super(logger, pressureService.threadPool, INTERVAL, true); + this.pressureService = pressureService; + rescheduleIfNecessary(); + } + + @Override + protected boolean mustReschedule() { + return true; + } + + @Override + protected void runInternal() { + if (pressureService.isSegmentReplicationBackpressureEnabled) { + final SegmentReplicationStats stats = pressureService.tracker.getStats(); + + // Find the shardId in node which is having stale replicas with highest current replication time. + // This way we only fail one shardId's stale replicas in every iteration of this background async task and there by decrease + // load gradually on node. + stats.getShardStats() + .entrySet() + .stream() + .flatMap( + entry -> pressureService.getStaleReplicas(entry.getValue().getReplicaStats()) + .stream() + .map(r -> Tuple.tuple(entry.getKey(), r.getCurrentReplicationTimeMillis())) + ) + .max(Comparator.comparingLong(Tuple::v2)) + .map(Tuple::v1) + .ifPresent(shardId -> { + final Set staleReplicas = pressureService.getStaleReplicas( + stats.getShardStats().get(shardId).getReplicaStats() + ); + final IndexService indexService = pressureService.indicesService.indexService(shardId.getIndex()); + final IndexShard primaryShard = indexService.getShard(shardId.getId()); + for (SegmentReplicationShardStats staleReplica : staleReplicas) { + if (staleReplica.getCurrentReplicationTimeMillis() > 2 * pressureService.maxReplicationTime.millis()) { + pressureService.shardStateAction.remoteShardFailed( + shardId, + staleReplica.getAllocationId(), + primaryShard.getOperationPrimaryTerm(), + true, + "replica too far behind primary, marking as stale", + null, + new ActionListener<>() { + @Override + public void onResponse(Void unused) { + logger.trace( + "Successfully failed remote shardId [{}] allocation id [{}]", + shardId, + staleReplica.getAllocationId() + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to send remote shard failure", e); + } + } + ); + } + } + }); + } + } + + @Override + protected String getThreadPool() { + return ThreadPool.Names.GENERIC; + } + + @Override + public String toString() { + return "fail_stale_replica"; + } + + } } diff --git a/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java b/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java index a050a4c2243db..3bc84c2c44be8 100644 --- a/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java +++ b/server/src/test/java/org/opensearch/index/SegmentReplicationPressureServiceTests.java @@ -8,7 +8,9 @@ package org.opensearch.index; +import org.mockito.Mockito; import org.mockito.stubbing.Answer; +import org.opensearch.cluster.action.shard.ShardStateAction; import org.opensearch.cluster.metadata.IndexMetadata; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.ClusterSettings; @@ -21,6 +23,7 @@ import org.opensearch.index.shard.ShardId; import org.opensearch.indices.IndicesService; import org.opensearch.indices.replication.common.ReplicationType; +import org.opensearch.threadpool.ThreadPool; import java.util.Iterator; import java.util.List; @@ -29,13 +32,20 @@ import java.util.concurrent.TimeUnit; import static java.util.Arrays.asList; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.times; import static org.opensearch.index.SegmentReplicationPressureService.MAX_REPLICATION_TIME_SETTING; import static org.opensearch.index.SegmentReplicationPressureService.SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED; public class SegmentReplicationPressureServiceTests extends OpenSearchIndexLevelReplicationTestCase { + private static ShardStateAction shardStateAction = Mockito.mock(ShardStateAction.class); private static final Settings settings = Settings.builder() .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT) .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true) @@ -181,6 +191,36 @@ public void testIsSegrepLimitBreached_underStaleNodeLimit() throws Exception { } } + public void testFailStaleReplicaTask() throws Exception { + final Settings settings = Settings.builder() + .put(IndexMetadata.SETTING_REPLICATION_TYPE, ReplicationType.SEGMENT) + .put(SEGMENT_REPLICATION_INDEXING_PRESSURE_ENABLED.getKey(), true) + .put(MAX_REPLICATION_TIME_SETTING.getKey(), TimeValue.timeValueMillis(10)) + .build(); + + try (ReplicationGroup shards = createGroup(1, settings, new NRTReplicationEngineFactory())) { + shards.startAll(); + final IndexShard primaryShard = shards.getPrimary(); + SegmentReplicationPressureService service = buildPressureService(settings, primaryShard); + + // index docs in batches without refreshing + indexInBatches(5, shards, primaryShard); + + // assert that replica shard is few checkpoints behind primary + Set replicationStats = primaryShard.getReplicationStats(); + assertEquals(1, replicationStats.size()); + SegmentReplicationShardStats shardStats = replicationStats.stream().findFirst().get(); + assertEquals(5, shardStats.getCheckpointsBehindCount()); + + // call the background task + service.getFailStaleReplicaTask().runInternal(); + + // verify that remote shard failed method is called which fails the replica shards falling behind. + verify(shardStateAction, times(1)).remoteShardFailed(any(), anyString(), anyLong(), anyBoolean(), anyString(), any(), any()); + replicateSegments(primaryShard, shards.getReplicas()); + } + } + private int indexInBatches(int count, ReplicationGroup shards, IndexShard primaryShard) throws Exception { int totalDocs = 0; for (int i = 0; i < count; i++) { @@ -202,6 +242,6 @@ private SegmentReplicationPressureService buildPressureService(Settings settings ClusterService clusterService = mock(ClusterService.class); when(clusterService.getClusterSettings()).thenReturn(new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS)); - return new SegmentReplicationPressureService(settings, clusterService, indicesService); + return new SegmentReplicationPressureService(settings, clusterService, indicesService, shardStateAction, mock(ThreadPool.class)); } } diff --git a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java index 40cd924928541..6c4b636e3c002 100644 --- a/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/opensearch/snapshots/SnapshotResiliencyTests.java @@ -1984,7 +1984,13 @@ public void onFailure(final Exception e) { new UpdateHelper(scriptService), actionFilters, new IndexingPressureService(settings, clusterService), - new SegmentReplicationPressureService(settings, clusterService, mock(IndicesService.class)), + new SegmentReplicationPressureService( + settings, + clusterService, + mock(IndicesService.class), + mock(ShardStateAction.class), + mock(ThreadPool.class) + ), new SystemIndices(emptyMap()) ); actions.put(