diff --git a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java index af6b27fef98f1..399618de90e2a 100644 --- a/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java +++ b/server/src/main/java/org/elasticsearch/cluster/routing/RoutingNode.java @@ -25,6 +25,9 @@ import java.util.Objects; import java.util.Set; import java.util.function.Predicate; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.toCollection; /** * A {@link RoutingNode} represents a cluster node associated with a single {@link DiscoveryNode} including all shards @@ -211,29 +214,11 @@ void remove(ShardRouting shard) { /** * Determine the number of shards with a specific state - * @param states set of states which should be counted + * @param state which should be counted * @return number of shards */ - public int numberOfShardsWithState(ShardRoutingState... states) { - if (states.length == 1) { - if (states[0] == ShardRoutingState.INITIALIZING) { - return initializingShards.size(); - } else if (states[0] == ShardRoutingState.RELOCATING) { - return relocatingShards.size(); - } else if (states[0] == ShardRoutingState.STARTED) { - return startedShards.size(); - } - } - - int count = 0; - for (ShardRouting shardEntry : this) { - for (ShardRoutingState state : states) { - if (shardEntry.state() == state) { - count++; - } - } - } - return count; + public int numberOfShardsWithState(ShardRoutingState state) { + return internalGetShardsWithState(state).size(); } /** @@ -242,20 +227,7 @@ public int numberOfShardsWithState(ShardRoutingState... states) { * @return List of shards */ public List shardsWithState(ShardRoutingState state) { - if (state == ShardRoutingState.INITIALIZING) { - return new ArrayList<>(initializingShards); - } else if (state == ShardRoutingState.RELOCATING) { - return new ArrayList<>(relocatingShards); - } else if (state == ShardRoutingState.STARTED) { - return new ArrayList<>(startedShards); - } - List shards = new ArrayList<>(); - for (ShardRouting shardEntry : this) { - if (shardEntry.state() == state) { - shards.add(shardEntry); - } - } - return shards; + return new ArrayList<>(internalGetShardsWithState(state)); } private static final ShardRouting[] EMPTY_SHARD_ROUTING_ARRAY = new ShardRouting[0]; @@ -279,49 +251,28 @@ public ShardRouting[] started() { * @return a list of shards */ public List shardsWithState(String index, ShardRoutingState... states) { - List shards = new ArrayList<>(); - - if (states.length == 1) { - if (states[0] == ShardRoutingState.INITIALIZING) { - for (ShardRouting shardEntry : initializingShards) { - if (shardEntry.getIndexName().equals(index) == false) { - continue; - } - shards.add(shardEntry); - } - return shards; - } else if (states[0] == ShardRoutingState.RELOCATING) { - for (ShardRouting shardEntry : relocatingShards) { - if (shardEntry.getIndexName().equals(index) == false) { - continue; - } - shards.add(shardEntry); - } - return shards; - } else if (states[0] == ShardRoutingState.STARTED) { - for (ShardRouting shardEntry : startedShards) { - if (shardEntry.getIndexName().equals(index) == false) { - continue; - } - shards.add(shardEntry); - } - return shards; - } - } + return Stream.of(states).flatMap(state -> shardsWithState(index, state).stream()).collect(toCollection(ArrayList::new)); + } - for (ShardRouting shardEntry : this) { - if (shardEntry.getIndexName().equals(index) == false) { - continue; - } - for (ShardRoutingState state : states) { - if (shardEntry.state() == state) { - shards.add(shardEntry); - } + public List shardsWithState(String index, ShardRoutingState state) { + var shards = new ArrayList(); + for (ShardRouting shardEntry : internalGetShardsWithState(state)) { + if (shardEntry.getIndexName().equals(index)) { + shards.add(shardEntry); } } return shards; } + private LinkedHashSet internalGetShardsWithState(ShardRoutingState state) { + return switch (state) { + case UNASSIGNED -> throw new IllegalArgumentException("Unassigned shards are not linked to a routing node"); + case INITIALIZING -> initializingShards; + case STARTED -> startedShards; + case RELOCATING -> relocatingShards; + }; + } + /** * The number of shards on this node that will not be eventually relocated. */ diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/RoutingNodeTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/RoutingNodeTests.java index f78c0ff3fe6ba..c40f95f384b28 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/RoutingNodeTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/RoutingNodeTests.java @@ -96,7 +96,6 @@ public void testRemove() { } public void testNumberOfShardsWithState() { - assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.INITIALIZING, ShardRoutingState.STARTED), equalTo(2)); assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.STARTED), equalTo(1)); assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.RELOCATING), equalTo(1)); assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.INITIALIZING), equalTo(1)); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/FailedShardsRoutingTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/FailedShardsRoutingTests.java index 86ed2badab0e6..5db2ed8de77f2 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/FailedShardsRoutingTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/FailedShardsRoutingTests.java @@ -41,7 +41,6 @@ import static org.elasticsearch.cluster.routing.ShardRoutingState.UNASSIGNED; import static org.hamcrest.Matchers.anyOf; import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -507,10 +506,14 @@ public void testRebalanceFailure() { RoutingNodes routingNodes = clusterState.getRoutingNodes(); assertThat(clusterState.routingTable().index("test").size(), equalTo(2)); - assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2)); - assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(3)); - assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2)); - assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(3)); + assertThat( + routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING), + equalTo(2) + ); + assertThat( + routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING), + equalTo(2) + ); assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(1)); logger.info("Fail the shards on node 3"); @@ -521,10 +524,14 @@ public void testRebalanceFailure() { routingNodes = clusterState.getRoutingNodes(); assertThat(clusterState.routingTable().index("test").size(), equalTo(2)); - assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2)); - assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(3)); - assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2)); - assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(3)); + assertThat( + routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING), + equalTo(2) + ); + assertThat( + routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING), + equalTo(2) + ); if (strategy.isBalancedShardsAllocator()) { assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(1)); diff --git a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TenShardsOneReplicaRoutingTests.java b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TenShardsOneReplicaRoutingTests.java index f03683dd36755..e1134699db625 100644 --- a/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TenShardsOneReplicaRoutingTests.java +++ b/server/src/test/java/org/elasticsearch/cluster/routing/allocation/TenShardsOneReplicaRoutingTests.java @@ -143,10 +143,16 @@ public void testSingleIndexFirstStartPrimaryThenBackups() { routingNodes = clusterState.getRoutingNodes(); assertThat(clusterState.routingTable().index("test").size(), equalTo(10)); - assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(10)); assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(10)); - assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(10)); + assertThat( + routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING), + equalTo(10) + ); assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(10)); + assertThat( + routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING), + equalTo(10) + ); assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(6)); logger.info("Start the shards on node 3"); diff --git a/test/framework/src/main/java/org/elasticsearch/cluster/routing/RoutingNodesHelper.java b/test/framework/src/main/java/org/elasticsearch/cluster/routing/RoutingNodesHelper.java index 4387f36effa06..95420293e80cf 100644 --- a/test/framework/src/main/java/org/elasticsearch/cluster/routing/RoutingNodesHelper.java +++ b/test/framework/src/main/java/org/elasticsearch/cluster/routing/RoutingNodesHelper.java @@ -21,28 +21,29 @@ private RoutingNodesHelper() {} public static List shardsWithState(RoutingNodes routingNodes, ShardRoutingState state) { List shards = new ArrayList<>(); - for (RoutingNode routingNode : routingNodes) { - shards.addAll(routingNode.shardsWithState(state)); - } if (state == ShardRoutingState.UNASSIGNED) { routingNodes.unassigned().forEach(shards::add); + } else { + for (RoutingNode routingNode : routingNodes) { + shards.addAll(routingNode.shardsWithState(state)); + } } return shards; } - public static List shardsWithState(RoutingNodes routingNodes, String index, ShardRoutingState... state) { + public static List shardsWithState(RoutingNodes routingNodes, String index, ShardRoutingState... states) { List shards = new ArrayList<>(); - for (RoutingNode routingNode : routingNodes) { - shards.addAll(routingNode.shardsWithState(index, state)); - } - for (ShardRoutingState s : state) { - if (s == ShardRoutingState.UNASSIGNED) { + for (ShardRoutingState state : states) { + if (state == ShardRoutingState.UNASSIGNED) { for (ShardRouting unassignedShard : routingNodes.unassigned()) { if (unassignedShard.index().getName().equals(index)) { shards.add(unassignedShard); } } - break; + } else { + for (RoutingNode routingNode : routingNodes) { + shards.addAll(routingNode.shardsWithState(index, state)); + } } } return shards; @@ -64,7 +65,6 @@ public static RoutingNode routingNode(String nodeId, DiscoveryNode node, ShardRo for (ShardRouting shardRouting : shards) { routingNode.add(shardRouting); } - return routingNode; } } diff --git a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java index c682ca93488d2..985a0948516eb 100644 --- a/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java +++ b/x-pack/plugin/watcher/src/test/java/org/elasticsearch/xpack/watcher/WatcherIndexingListenerTests.java @@ -330,7 +330,7 @@ public void testClusterChangedWatchAliasChanged() throws Exception { boolean emptyShards = randomBoolean(); if (emptyShards) { - when(routingNode.shardsWithState(eq(newActiveWatchIndex), any())).thenReturn(Collections.emptyList()); + when(routingNode.shardsWithState(eq(newActiveWatchIndex), any(ShardRoutingState[].class))).thenReturn(Collections.emptyList()); } else { Index index = new Index(newActiveWatchIndex, "uuid"); ShardId shardId = new ShardId(index, 0);