Skip to content

Commit

Permalink
refactor to avoid coupling
Browse files Browse the repository at this point in the history
Signed-off-by: Poojita Raj <[email protected]>
  • Loading branch information
Poojita-Raj committed Aug 29, 2023
1 parent dc41002 commit 16c4751
Show file tree
Hide file tree
Showing 4 changed files with 102 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.block.ClusterBlockLevel;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
Expand All @@ -51,8 +52,6 @@
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;

import static org.opensearch.action.get.TransportGetAction.shouldForcePrimaryRouting;

/**
* Perform the multi get action.
*
Expand All @@ -78,6 +77,10 @@ public TransportMultiGetAction(
this.indexNameExpressionResolver = resolver;
}

protected static boolean shouldForcePrimaryRouting(Metadata metadata, boolean realtime, String preference, String indexName) {
return metadata.isSegmentReplicationEnabled(indexName) && realtime && preference == null;
}

@Override
protected void doExecute(Task task, final MultiGetRequest request, final ActionListener<MultiGetResponse> listener) {
ClusterState clusterState = clusterService.state();
Expand Down Expand Up @@ -112,7 +115,7 @@ protected void doExecute(Task task, final MultiGetRequest request, final ActionL

MultiGetShardRequest shardRequest = shardRequests.get(shardId);
if (shardRequest == null) {
if (shouldForcePrimaryRouting(clusterState.getMetadata(), request.realtime, request.preference, concreteSingleIndex)) {
if (shouldForcePrimaryRouting(clusterState.getMetadata(), request.realtime(), request.preference(), concreteSingleIndex)) {
request.preference(Preference.PRIMARY.type());
}
shardRequest = new MultiGetShardRequest(request, shardId.getIndexName(), shardId.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -366,40 +366,37 @@ public ShardRouting activePrimary(ShardId shardId) {
return null;
}

/**
* Returns one active replica shard for the given shard id or <code>null</code> if
* no active replica is found.
*
* Since replicas could possibly be on nodes with an older version of OpenSearch than
* the primary is, this will return replicas on the highest version of OpenSearch when
* document replication strategy is in use, and will return replicas on oldest version
* of OpenSearch when segment replication is enabled.
*
*/
public ShardRouting activeReplicaBasedOnReplicationStrategy(ShardId shardId) {
public ShardRouting activeReplicaWithHighestVersion(ShardId shardId) {
// It's possible for replicaNodeVersion to be null, when disassociating dead nodes
// that have been removed, the shards are failed, and part of the shard failing
// calls this method with an out-of-date RoutingNodes, where the version might not
// be accessible. Therefore, we need to protect against the version being null
// (meaning the node will be going away).
Stream<ShardRouting> candidateShards = assignedShards(shardId).stream()
return assignedShards(shardId).stream()
.filter(shr -> !shr.primary() && shr.active())
.filter(shr -> node(shr.currentNodeId()) != null);
if (metadata.isSegmentReplicationEnabled(shardId.getIndexName())) {
return candidateShards.min(
.filter(shr -> node(shr.currentNodeId()) != null)
.max(
Comparator.comparing(
shr -> node(shr.currentNodeId()).node(),
Comparator.nullsFirst(Comparator.comparing(DiscoveryNode::getVersion))
)
).orElse(null);
)
.orElse(null);
}

}
return candidateShards.max(
Comparator.comparing(
shr -> node(shr.currentNodeId()).node(),
Comparator.nullsFirst(Comparator.comparing(DiscoveryNode::getVersion))
public ShardRouting activeReplicaWithOldestVersion(ShardId shardId) {
// It's possible for replicaNodeVersion to be null. Therefore, we need to protect against the version being null
// (meaning the node will be going away).
return assignedShards(shardId).stream()
.filter(shr -> !shr.primary() && shr.active())
.filter(shr -> node(shr.currentNodeId()) != null)
.min(
Comparator.comparing(
shr -> node(shr.currentNodeId()).node(),
Comparator.nullsFirst(Comparator.comparing(DiscoveryNode::getVersion))
)
)
).orElse(null);
.orElse(null);
}

/**
Expand Down Expand Up @@ -736,7 +733,12 @@ private void unassignPrimaryAndPromoteActiveReplicaIfExists(
RoutingChangesObserver routingChangesObserver
) {
assert failedShard.primary();
ShardRouting activeReplica = activeReplicaBasedOnReplicationStrategy(failedShard.shardId());
ShardRouting activeReplica;
if (metadata.isSegmentReplicationEnabled(failedShard.getIndexName())) {
activeReplica = activeReplicaWithOldestVersion(failedShard.shardId());
} else {
activeReplica = activeReplicaWithHighestVersion(failedShard.shardId());
}
if (activeReplica == null) {
moveToUnassigned(failedShard, unassignedInfo);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import org.opensearch.cluster.metadata.Metadata;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.routing.OperationRouting;
import org.opensearch.cluster.routing.Preference;
import org.opensearch.cluster.routing.ShardIterator;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Settings;
Expand All @@ -58,6 +59,7 @@
import org.opensearch.core.tasks.TaskId;
import org.opensearch.core.xcontent.MediaTypeRegistry;
import org.opensearch.indices.IndicesService;
import org.opensearch.indices.replication.common.ReplicationType;
import org.opensearch.tasks.Task;
import org.opensearch.tasks.TaskManager;
import org.opensearch.test.OpenSearchTestCase;
Expand All @@ -68,6 +70,7 @@
import org.junit.AfterClass;
import org.junit.BeforeClass;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
Expand All @@ -91,39 +94,16 @@ public class TransportMultiGetActionTests extends OpenSearchTestCase {
private static TransportMultiGetAction transportAction;
private static TransportShardMultiGetAction shardAction;

@BeforeClass
public static void beforeClass() throws Exception {
threadPool = new TestThreadPool(TransportMultiGetActionTests.class.getSimpleName());

transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(
Settings.builder().put("node.name", "node1").build(),
boundAddress.publishAddress(),
randomBase64UUID()
),
null,
emptySet()
) {
@Override
public TaskManager getTaskManager() {
return taskManager;
}
};

final Index index1 = new Index("index1", randomBase64UUID());
final Index index2 = new Index("index2", randomBase64UUID());
final ClusterState clusterState = ClusterState.builder(new ClusterName(TransportMultiGetActionTests.class.getSimpleName()))
private static ClusterState clusterState(ReplicationType replicationType, Index index1, Index index2) throws IOException {
return ClusterState.builder(new ClusterName(TransportMultiGetActionTests.class.getSimpleName()))
.metadata(
new Metadata.Builder().put(
new IndexMetadata.Builder(index1.getName()).settings(
Settings.builder()
.put("index.version.created", Version.CURRENT)
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put(IndexMetadata.SETTING_REPLICATION_TYPE, replicationType)
.put(IndexMetadata.SETTING_INDEX_UUID, index1.getUUID())
)
.putMapping(
Expand All @@ -149,6 +129,7 @@ public TaskManager getTaskManager() {
.put("index.version.created", Version.CURRENT)
.put("index.number_of_shards", 1)
.put("index.number_of_replicas", 1)
.put(IndexMetadata.SETTING_REPLICATION_TYPE, replicationType)
.put(IndexMetadata.SETTING_INDEX_UUID, index1.getUUID())
)
.putMapping(
Expand All @@ -170,6 +151,34 @@ public TaskManager getTaskManager() {
)
)
.build();
}

@BeforeClass
public static void beforeClass() throws Exception {
threadPool = new TestThreadPool(TransportMultiGetActionTests.class.getSimpleName());

transportService = new TransportService(
Settings.EMPTY,
mock(Transport.class),
threadPool,
TransportService.NOOP_TRANSPORT_INTERCEPTOR,
boundAddress -> DiscoveryNode.createLocal(
Settings.builder().put("node.name", "node1").build(),
boundAddress.publishAddress(),
randomBase64UUID()
),
null,
emptySet()
) {
@Override
public TaskManager getTaskManager() {
return taskManager;
}
};

final Index index1 = new Index("index1", randomBase64UUID());
final Index index2 = new Index("index2", randomBase64UUID());
ClusterState clusterState = clusterState(randomBoolean() ? ReplicationType.SEGMENT : ReplicationType.DOCUMENT, index1, index2);

final ShardIterator index1ShardIterator = mock(ShardIterator.class);
when(index1ShardIterator.shardId()).thenReturn(new ShardId(index1, randomInt()));
Expand Down Expand Up @@ -285,6 +294,30 @@ protected void executeShardAction(

}

public void testShouldForcePrimaryRouting() throws IOException {
final Index index1 = new Index("index1", randomBase64UUID());
final Index index2 = new Index("index2", randomBase64UUID());
Metadata metadata = clusterState(ReplicationType.SEGMENT, index1, index2).getMetadata();

// should return false since preference is set for request
assertFalse(TransportMultiGetAction.shouldForcePrimaryRouting(metadata, true, Preference.REPLICA.type(), "index1"));

// should return false since request is not realtime
assertFalse(TransportMultiGetAction.shouldForcePrimaryRouting(metadata, false, null, "index2"));

// should return true since segment replication is enabled
assertTrue(TransportMultiGetAction.shouldForcePrimaryRouting(metadata, true, null, "index1"));

// should return false since index doesn't exist
assertFalse(TransportMultiGetAction.shouldForcePrimaryRouting(metadata, true, null, "index3"));

metadata = clusterState(ReplicationType.DOCUMENT, index1, index2).getMetadata();

// should fail since document replication enabled
assertFalse(TransportGetAction.shouldForcePrimaryRouting(metadata, true, null, "index1"));

}

private static Task createTask() {
return new Task(
randomLong(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ public void testFailAllReplicasInitializingOnPrimaryFail() {
clusterState = startShardsAndReroute(allocation, clusterState, clusterState.getRoutingNodes().shardsWithState(INITIALIZING).get(0));
assertThat(clusterState.getRoutingNodes().shardsWithState(STARTED).size(), equalTo(2));
assertThat(clusterState.getRoutingNodes().shardsWithState(INITIALIZING).size(), equalTo(1));
ShardRouting startedReplica = clusterState.getRoutingNodes().activeReplicaBasedOnReplicationStrategy(shardId);
ShardRouting startedReplica = clusterState.getRoutingNodes().activeReplicaWithHighestVersion(shardId);

// fail the primary shard, check replicas get removed as well...
ShardRouting primaryShardToFail = clusterState.routingTable().index("test").shard(0).primaryShard();
Expand Down Expand Up @@ -726,7 +726,12 @@ private void testReplicaIsPromoted(boolean isSegmentReplicationEnabled) {
assertThat(clusterState.getRoutingNodes().shardsWithState(STARTED).size(), equalTo(4));
assertThat(clusterState.getRoutingNodes().shardsWithState(UNASSIGNED).size(), equalTo(0));

ShardRouting startedReplica = clusterState.getRoutingNodes().activeReplicaBasedOnReplicationStrategy(shardId);
ShardRouting startedReplica;
if (isSegmentReplicationEnabled) {
startedReplica = clusterState.getRoutingNodes().activeReplicaWithOldestVersion(shardId);
} else {
startedReplica = clusterState.getRoutingNodes().activeReplicaWithHighestVersion(shardId);
}
logger.info("--> all shards allocated, replica that should be promoted: {}", startedReplica);

// fail the primary shard again and make sure the correct replica is promoted
Expand Down Expand Up @@ -764,7 +769,11 @@ private void testReplicaIsPromoted(boolean isSegmentReplicationEnabled) {
}
}

startedReplica = clusterState.getRoutingNodes().activeReplicaBasedOnReplicationStrategy(shardId);
if (isSegmentReplicationEnabled) {
startedReplica = clusterState.getRoutingNodes().activeReplicaWithOldestVersion(shardId);
} else {
startedReplica = clusterState.getRoutingNodes().activeReplicaWithHighestVersion(shardId);
}
logger.info("--> failing primary shard a second time, should select: {}", startedReplica);

// fail the primary shard again, and ensure the same thing happens
Expand Down

0 comments on commit 16c4751

Please sign in to comment.