
Adding onNewCheckpoint to Start Replication on Replica Shard when Segment Replication is turned on #3540

Merged: merged 16 commits on Jun 22, 2022
@@ -173,6 +173,7 @@ public final EngineConfig config() {
* Return the latest active SegmentInfos from the engine.
* @return {@link SegmentInfos}
*/
@Nullable
protected abstract SegmentInfos getLatestSegmentInfos();

/**
@@ -2289,7 +2289,7 @@ protected SegmentInfos getLastCommittedSegmentInfos() {
}

@Override
public SegmentInfos getLatestSegmentInfos() {
protected SegmentInfos getLatestSegmentInfos() {
OpenSearchDirectoryReader reader = null;
try {
reader = internalReaderManager.acquire();
59 changes: 52 additions & 7 deletions server/src/main/java/org/opensearch/index/shard/IndexShard.java
@@ -161,7 +161,7 @@
import org.opensearch.indices.recovery.RecoveryListener;
import org.opensearch.indices.recovery.RecoveryState;
import org.opensearch.indices.recovery.RecoveryTarget;
import org.opensearch.indices.replication.checkpoint.PublishCheckpointRequest;
import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.indices.replication.checkpoint.SegmentReplicationCheckpointPublisher;
import org.opensearch.repositories.RepositoriesService;
@@ -1396,15 +1396,60 @@ public GatedCloseable<IndexCommit> acquireSafeIndexCommit() throws EngineExcepti
* Returns the latest replication checkpoint that the shard has received
*/
public ReplicationCheckpoint getLatestReplicationCheckpoint() {
return new ReplicationCheckpoint(shardId, 0, 0, 0, 0);
try (final GatedCloseable<SegmentInfos> snapshot = getSegmentInfosSnapshot()) {
return Optional.ofNullable(snapshot.get())
.map(
segmentInfos -> new ReplicationCheckpoint(
this.shardId,
getOperationPrimaryTerm(),
segmentInfos.getGeneration(),
getProcessedLocalCheckpoint(),
segmentInfos.getVersion()
)
)
.orElse(
new ReplicationCheckpoint(
shardId,
getOperationPrimaryTerm(),
SequenceNumbers.NO_OPS_PERFORMED,
getProcessedLocalCheckpoint(),
SequenceNumbers.NO_OPS_PERFORMED
)
);
} catch (IOException ex) {
throw new OpenSearchException("Error Closing SegmentInfos Snapshot", ex);
}
}

/**
* Invoked when a new checkpoint is received from a primary shard. Starts the copy process.
*/
public synchronized void onNewCheckpoint(final PublishCheckpointRequest request) {
assert shardRouting.primary() == false;
// TODO
* Checks if checkpoint should be processed
*
* @param requestCheckpoint received checkpoint that is checked for processing
* @return true if checkpoint should be processed
*/
public final boolean shouldProcessCheckpoint(ReplicationCheckpoint requestCheckpoint) {
if (state().equals(IndexShardState.STARTED) == false) {
logger.trace(() -> new ParameterizedMessage("Ignoring new replication checkpoint - shard is not started {}", state()));
return false;
}
ReplicationCheckpoint localCheckpoint = getLatestReplicationCheckpoint();
if (localCheckpoint.isAheadOf(requestCheckpoint)) {
logger.trace(
() -> new ParameterizedMessage(
"Ignoring new replication checkpoint - Shard is already on checkpoint {} that is ahead of {}",
localCheckpoint,
requestCheckpoint
)
);
return false;
}
if (localCheckpoint.equals(requestCheckpoint)) {
logger.trace(
() -> new ParameterizedMessage("Ignoring new replication checkpoint - Shard is already on checkpoint {}", requestCheckpoint)
);
return false;
}
return true;
}

/**
@@ -282,6 +282,8 @@ protected void configure() {
bind(RetentionLeaseSyncer.class).asEagerSingleton();
if (FeatureFlags.isEnabled(FeatureFlags.REPLICATION_TYPE)) {
bind(SegmentReplicationCheckpointPublisher.class).asEagerSingleton();
} else {
bind(SegmentReplicationCheckpointPublisher.class).toInstance(SegmentReplicationCheckpointPublisher.EMPTY);
}
}
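The EMPTY constant bound in the else branch above is what lets the publisher be injected unconditionally. A minimal sketch of the pattern, assuming PublishAction is a one-method interface over IndexShard (the delegation this(publishAction::publish) later in this diff implies that shape); this is illustrative, not a quote of the actual class:

// Sketch, not the verbatim constant: a publisher whose PublishAction drops
// checkpoints, so callers can invoke publish(indexShard) without null checks
// when segment replication is disabled.
public static final SegmentReplicationCheckpointPublisher EMPTY =
    new SegmentReplicationCheckpointPublisher(indexShard -> { /* no-op */ });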

@@ -38,7 +38,7 @@
*
* @opensearch.internal
*/
public final class SegmentReplicationTargetService implements IndexEventListener {
public class SegmentReplicationTargetService implements IndexEventListener {

private static final Logger logger = LogManager.getLogger(SegmentReplicationTargetService.class);

@@ -84,6 +84,39 @@ public void beforeIndexShardClosed(ShardId shardId, @Nullable IndexShard indexSh
}
}

/**
* Invoked when a new checkpoint is received from a primary shard.
* It checks if a new checkpoint should be processed or not and starts replication if needed.
* @param receivedCheckpoint received checkpoint that is checked for processing
* @param replicaShard replica shard on which checkpoint is received
*/
public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedCheckpoint, final IndexShard replicaShard) {
if (onGoingReplications.isShardReplicating(replicaShard.shardId())) {
logger.trace(
() -> new ParameterizedMessage(
"Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
replicaShard.getLatestReplicationCheckpoint()
)
);
return;
}
if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) {
startReplication(receivedCheckpoint, replicaShard, new SegmentReplicationListener() {
@Override
public void onReplicationDone(SegmentReplicationState state) {}

@Override
public void onReplicationFailure(SegmentReplicationState state, OpenSearchException e, boolean sendShardFailure) {
if (sendShardFailure == true) {
logger.error("replication failure", e);
replicaShard.failShard("replication failure", e);
}
}
});

}
}

public void startReplication(
final ReplicationCheckpoint checkpoint,
final IndexShard indexShard,
@@ -28,6 +28,7 @@
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardClosedException;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.node.NodeClosedException;
import org.opensearch.tasks.Task;
import org.opensearch.threadpool.ThreadPool;
@@ -52,6 +53,8 @@ public class PublishCheckpointAction extends TransportReplicationAction<
public static final String ACTION_NAME = "indices:admin/publishCheckpoint";
protected static Logger logger = LogManager.getLogger(PublishCheckpointAction.class);

private final SegmentReplicationTargetService replicationService;

@Inject
public PublishCheckpointAction(
Settings settings,
@@ -60,7 +63,8 @@ public PublishCheckpointAction(
IndicesService indicesService,
ThreadPool threadPool,
ShardStateAction shardStateAction,
ActionFilters actionFilters
ActionFilters actionFilters,
SegmentReplicationTargetService targetService
) {
super(
settings,
@@ -75,6 +79,7 @@ public PublishCheckpointAction(
PublishCheckpointRequest::new,
ThreadPool.Names.REFRESH
);
this.replicationService = targetService;
}

@Override
@@ -165,7 +170,7 @@ protected void shardOperationOnReplica(PublishCheckpointRequest request, IndexSh
ActionListener.completeWith(listener, () -> {
logger.trace("Checkpoint received on replica {}", request);
if (request.getCheckpoint().getShardId().equals(replica.shardId())) {
replica.onNewCheckpoint(request);
replicationService.onNewCheckpoint(request.getCheckpoint(), replica);
}
return new ReplicaResult();
});
@@ -115,7 +115,7 @@ public int hashCode() {
* Checks if other is ahead of the current checkpoint by comparing primaryTerm and segmentInfosVersion. Returns true if other is null
*/
public boolean isAheadOf(@Nullable ReplicationCheckpoint other) {
return other == null || segmentInfosVersion > other.getSegmentInfosVersion();
return other == null || segmentInfosVersion > other.getSegmentInfosVersion() || primaryTerm > other.getPrimaryTerm();
}
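A quick sanity check of the new ordering - values are illustrative, and the five-argument constructor order (shardId, primaryTerm, segments generation, seqNo, segmentInfosVersion) is taken from this PR's own test code; shardId is assumed in scope:

// A checkpoint from a new primary (higher term, same segmentInfosVersion)
// now counts as ahead, not only one with a higher segmentInfosVersion.
ReplicationCheckpoint before = new ReplicationCheckpoint(shardId, 1, 5, 10, 7);
ReplicationCheckpoint afterFailover = new ReplicationCheckpoint(shardId, 2, 5, 10, 7);
assert afterFailover.isAheadOf(before);           // primaryTerm 2 > 1
assert before.isAheadOf(afterFailover) == false;  // neither version nor term is higher
assert before.isAheadOf(null);                    // null is always behind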

@Override
@@ -22,6 +22,7 @@ public class SegmentReplicationCheckpointPublisher {

private final PublishAction publishAction;

// This component is behind a feature flag, so it is bound manually in IndicesModule.
@Inject
public SegmentReplicationCheckpointPublisher(PublishCheckpointAction publishAction) {
this(publishAction::publish);
@@ -235,6 +235,16 @@ public boolean cancelForShard(ShardId shardId, String reason) {
return cancelled;
}

/**
* Checks if a shard is currently replicating
*
* @param shardId shardId for which to check if replicating
* @return true if shard is currently replicating
*/
public boolean isShardReplicating(ShardId shardId) {
return onGoingTargetEvents.values().stream().anyMatch(t -> t.indexShard.shardId().equals(shardId));
Member:
Nit-pick - this is a problem to be solved down the line: isShardReplicating seems like a frequent check, and constantly flattening the map to a stream of values may end up being performance-intensive. We should consider whether we could make shardId the key of the map to speed up this check.

Member (Author):
Yes, we can do that, but it would have to change in multiple places, so I'm thinking of doing that in a separate PR.
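A hedged sketch of that follow-up - names are hypothetical, since today onGoingTargetEvents is keyed by a replication id rather than by shard:

// Maintain a secondary ShardId index alongside the existing map so the
// hot-path check is O(1) instead of streaming over every ongoing target.
private final java.util.concurrent.ConcurrentMap<ShardId, Integer> replicatingShards =
    new java.util.concurrent.ConcurrentHashMap<>();

void onTargetStarted(ShardId shardId) {   // call when a target starts
    replicatingShards.merge(shardId, 1, Integer::sum);
}

void onTargetFinished(ShardId shardId) {  // call on done, failure, or cancel
    replicatingShards.computeIfPresent(shardId, (id, n) -> n == 1 ? null : n - 1);
}

public boolean isShardReplicating(ShardId shardId) {
    return replicatingShards.containsKey(shardId);
}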

}

/**
* a reference to {@link ReplicationTarget}, which implements {@link AutoCloseable}. closing the reference
* causes {@link ReplicationTarget#decRef()} to be called. This makes sure that the underlying resources
11 changes: 11 additions & 0 deletions server/src/main/java/org/opensearch/node/Node.java
@@ -38,6 +38,8 @@
import org.apache.lucene.util.SetOnce;
import org.opensearch.common.util.FeatureFlags;
import org.opensearch.index.IndexingPressureService;
import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.indices.replication.SegmentReplicationSourceService;
import org.opensearch.watcher.ResourceWatcherService;
import org.opensearch.Assertions;
@@ -936,6 +938,15 @@ protected Node(
b.bind(PeerRecoveryTargetService.class)
.toInstance(new PeerRecoveryTargetService(threadPool, transportService, recoverySettings, clusterService));
if (FeatureFlags.isEnabled(REPLICATION_TYPE)) {
b.bind(SegmentReplicationTargetService.class)
.toInstance(
new SegmentReplicationTargetService(
threadPool,
recoverySettings,
transportService,
new SegmentReplicationSourceFactory(transportService, recoverySettings, clusterService)
)
);
b.bind(SegmentReplicationSourceService.class)
.toInstance(new SegmentReplicationSourceService(indicesService, transportService, recoverySettings));
}
@@ -9,6 +9,7 @@
package org.opensearch.indices.replication;

import org.junit.Assert;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.opensearch.OpenSearchException;
import org.opensearch.action.ActionListener;
@@ -18,15 +19,13 @@
import org.opensearch.index.shard.IndexShardTestCase;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.indices.replication.common.ReplicationLuceneIndex;
import org.opensearch.transport.TransportService;

import java.io.IOException;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.when;
import static org.mockito.Mockito.*;

public class SegmentReplicationTargetServiceTests extends IndexShardTestCase {

@@ -42,7 +41,7 @@ public void setUp() throws Exception {
final ClusterSettings clusterSettings = new ClusterSettings(settings, ClusterSettings.BUILT_IN_CLUSTER_SETTINGS);
final RecoverySettings recoverySettings = new RecoverySettings(settings, clusterSettings);
final TransportService transportService = mock(TransportService.class);
indexShard = newShard(false, settings);
indexShard = newStartedShard(false, settings);
checkpoint = new ReplicationCheckpoint(indexShard.shardId(), 0L, 0L, 0L, 0L);
SegmentReplicationSourceFactory replicationSourceFactory = mock(SegmentReplicationSourceFactory.class);
replicationSource = mock(SegmentReplicationSource.class);
@@ -57,7 +56,7 @@ public void tearDown() throws Exception {
super.tearDown();
}

public void testTargetReturnsSuccess_listenerCompletes() throws IOException {
public void testTargetReturnsSuccess_listenerCompletes() {
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
indexShard,
@@ -83,10 +82,9 @@ public void onReplicationFailure(SegmentReplicationState state, OpenSearchExcept
return null;
}).when(spy).startReplication(any());
sut.startReplication(spy);
closeShards(indexShard);
}

public void testTargetThrowsException() throws IOException {
public void testTargetThrowsException() {
final OpenSearchException expectedError = new OpenSearchException("Fail");
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
@@ -115,10 +113,71 @@ public void onReplicationFailure(SegmentReplicationState state, OpenSearchExcept
return null;
}).when(spy).startReplication(any());
sut.startReplication(spy);
closeShards(indexShard);
}

public void testBeforeIndexShardClosed_CancelsOngoingReplications() throws IOException {
public void testAlreadyOnNewCheckpoint() {
SegmentReplicationTargetService spy = spy(sut);
spy.onNewCheckpoint(indexShard.getLatestReplicationCheckpoint(), indexShard);
verify(spy, times(0)).startReplication(any(), any(), any());
}

public void testShardAlreadyReplicating() {
SegmentReplicationTargetService spy = spy(sut);
// Create a separate target and start it so the shard is already replicating.
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
indexShard,
replicationSource,
mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
);
final SegmentReplicationTarget spyTarget = Mockito.spy(target);
spy.startReplication(spyTarget);

// A new checkpoint comes in for the same IndexShard.
spy.onNewCheckpoint(checkpoint, indexShard);
verify(spy, times(0)).startReplication(any(), any(), any());
spyTarget.markAsDone();
}

public void testNewCheckpointBehindCurrentCheckpoint() {
SegmentReplicationTargetService spy = spy(sut);
spy.onNewCheckpoint(checkpoint, indexShard);
verify(spy, times(0)).startReplication(any(), any(), any());
}

public void testShardNotStarted() throws IOException {
SegmentReplicationTargetService spy = spy(sut);
IndexShard shard = newShard(false);
spy.onNewCheckpoint(checkpoint, shard);
verify(spy, times(0)).startReplication(any(), any(), any());
closeShards(shard);
}

public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOException {
allowShardFailures();
SegmentReplicationTargetService spy = spy(sut);
IndexShard spyShard = spy(indexShard);
ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint();
ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint(
cp.getShardId(),
cp.getPrimaryTerm(),
cp.getSegmentsGen(),
cp.getSeqNo(),
cp.getSegmentInfosVersion() + 1
);
ArgumentCaptor<SegmentReplicationTargetService.SegmentReplicationListener> captor = ArgumentCaptor.forClass(
SegmentReplicationTargetService.SegmentReplicationListener.class
);
doNothing().when(spy).startReplication(any(), any(), any());
spy.onNewCheckpoint(newCheckpoint, spyShard);
verify(spy, times(1)).startReplication(any(), any(), captor.capture());
SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue();
listener.onFailure(new SegmentReplicationState(new ReplicationLuceneIndex()), new OpenSearchException("testing"), true);
verify(spyShard).failShard(any(), any());
closeShard(indexShard, false);
}

public void testBeforeIndexShardClosed_CancelsOngoingReplications() {
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
indexShard,
@@ -128,7 +187,6 @@ public void testBeforeIndexShardClosed_CancelsOngoingReplications() throws IOExc
final SegmentReplicationTarget spy = Mockito.spy(target);
sut.startReplication(spy);
sut.beforeIndexShardClosed(indexShard.shardId(), indexShard, Settings.EMPTY);
Mockito.verify(spy, times(1)).cancel(any());
closeShards(indexShard);
verify(spy, times(1)).cancel(any());
}
}