Skip to content

Commit

Permalink
Simplify shardsWithState (#91991)
Browse files Browse the repository at this point in the history
  • Loading branch information
idegtiarenko authored Nov 30, 2022
1 parent 26bc894 commit c895331
Show file tree
Hide file tree
Showing 6 changed files with 59 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.stream.Stream;

import static java.util.stream.Collectors.toCollection;

/**
* A {@link RoutingNode} represents a cluster node associated with a single {@link DiscoveryNode} including all shards
Expand Down Expand Up @@ -211,29 +214,11 @@ void remove(ShardRouting shard) {

/**
* Determine the number of shards with a specific state
* @param states set of states which should be counted
* @param state which should be counted
* @return number of shards
*/
public int numberOfShardsWithState(ShardRoutingState... states) {
if (states.length == 1) {
if (states[0] == ShardRoutingState.INITIALIZING) {
return initializingShards.size();
} else if (states[0] == ShardRoutingState.RELOCATING) {
return relocatingShards.size();
} else if (states[0] == ShardRoutingState.STARTED) {
return startedShards.size();
}
}

int count = 0;
for (ShardRouting shardEntry : this) {
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
count++;
}
}
}
return count;
public int numberOfShardsWithState(ShardRoutingState state) {
return internalGetShardsWithState(state).size();
}

/**
Expand All @@ -242,20 +227,7 @@ public int numberOfShardsWithState(ShardRoutingState... states) {
* @return List of shards
*/
public List<ShardRouting> shardsWithState(ShardRoutingState state) {
if (state == ShardRoutingState.INITIALIZING) {
return new ArrayList<>(initializingShards);
} else if (state == ShardRoutingState.RELOCATING) {
return new ArrayList<>(relocatingShards);
} else if (state == ShardRoutingState.STARTED) {
return new ArrayList<>(startedShards);
}
List<ShardRouting> shards = new ArrayList<>();
for (ShardRouting shardEntry : this) {
if (shardEntry.state() == state) {
shards.add(shardEntry);
}
}
return shards;
return new ArrayList<>(internalGetShardsWithState(state));
}

private static final ShardRouting[] EMPTY_SHARD_ROUTING_ARRAY = new ShardRouting[0];
Expand All @@ -279,49 +251,28 @@ public ShardRouting[] started() {
* @return a list of shards
*/
public List<ShardRouting> shardsWithState(String index, ShardRoutingState... states) {
List<ShardRouting> shards = new ArrayList<>();

if (states.length == 1) {
if (states[0] == ShardRoutingState.INITIALIZING) {
for (ShardRouting shardEntry : initializingShards) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
shards.add(shardEntry);
}
return shards;
} else if (states[0] == ShardRoutingState.RELOCATING) {
for (ShardRouting shardEntry : relocatingShards) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
shards.add(shardEntry);
}
return shards;
} else if (states[0] == ShardRoutingState.STARTED) {
for (ShardRouting shardEntry : startedShards) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
shards.add(shardEntry);
}
return shards;
}
}
return Stream.of(states).flatMap(state -> shardsWithState(index, state).stream()).collect(toCollection(ArrayList::new));
}

for (ShardRouting shardEntry : this) {
if (shardEntry.getIndexName().equals(index) == false) {
continue;
}
for (ShardRoutingState state : states) {
if (shardEntry.state() == state) {
shards.add(shardEntry);
}
public List<ShardRouting> shardsWithState(String index, ShardRoutingState state) {
var shards = new ArrayList<ShardRouting>();
for (ShardRouting shardEntry : internalGetShardsWithState(state)) {
if (shardEntry.getIndexName().equals(index)) {
shards.add(shardEntry);
}
}
return shards;
}

private LinkedHashSet<ShardRouting> internalGetShardsWithState(ShardRoutingState state) {
return switch (state) {
case UNASSIGNED -> throw new IllegalArgumentException("Unassigned shards are not linked to a routing node");
case INITIALIZING -> initializingShards;
case STARTED -> startedShards;
case RELOCATING -> relocatingShards;
};
}

/**
* The number of shards on this node that will not be eventually relocated.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,6 @@ public void testRemove() {
}

public void testNumberOfShardsWithState() {
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.INITIALIZING, ShardRoutingState.STARTED), equalTo(2));
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.STARTED), equalTo(1));
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.RELOCATING), equalTo(1));
assertThat(routingNode.numberOfShardsWithState(ShardRoutingState.INITIALIZING), equalTo(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@
import static org.elasticsearch.cluster.routing.ShardRoutingState.UNASSIGNED;
import static org.hamcrest.Matchers.anyOf;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.lessThan;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.nullValue;

Expand Down Expand Up @@ -507,10 +506,14 @@ public void testRebalanceFailure() {
RoutingNodes routingNodes = clusterState.getRoutingNodes();

assertThat(clusterState.routingTable().index("test").size(), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(
routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING),
equalTo(2)
);
assertThat(
routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING),
equalTo(2)
);
assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(1));

logger.info("Fail the shards on node 3");
Expand All @@ -521,10 +524,14 @@ public void testRebalanceFailure() {
routingNodes = clusterState.getRoutingNodes();

assertThat(clusterState.routingTable().index("test").size(), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(2));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(3));
assertThat(
routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING),
equalTo(2)
);
assertThat(
routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING),
equalTo(2)
);

if (strategy.isBalancedShardsAllocator()) {
assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,10 +143,16 @@ public void testSingleIndexFirstStartPrimaryThenBackups() {
routingNodes = clusterState.getRoutingNodes();

assertThat(clusterState.routingTable().index("test").size(), equalTo(10));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED, RELOCATING), equalTo(10));
assertThat(routingNodes.node("node1").numberOfShardsWithState(STARTED), lessThan(10));
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED, RELOCATING), equalTo(10));
assertThat(
routingNodes.node("node1").numberOfShardsWithState(STARTED) + routingNodes.node("node1").numberOfShardsWithState(RELOCATING),
equalTo(10)
);
assertThat(routingNodes.node("node2").numberOfShardsWithState(STARTED), lessThan(10));
assertThat(
routingNodes.node("node2").numberOfShardsWithState(STARTED) + routingNodes.node("node2").numberOfShardsWithState(RELOCATING),
equalTo(10)
);
assertThat(routingNodes.node("node3").numberOfShardsWithState(INITIALIZING), equalTo(6));

logger.info("Start the shards on node 3");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,29 @@ private RoutingNodesHelper() {}

public static List<ShardRouting> shardsWithState(RoutingNodes routingNodes, ShardRoutingState state) {
List<ShardRouting> shards = new ArrayList<>();
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(state));
}
if (state == ShardRoutingState.UNASSIGNED) {
routingNodes.unassigned().forEach(shards::add);
} else {
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(state));
}
}
return shards;
}

public static List<ShardRouting> shardsWithState(RoutingNodes routingNodes, String index, ShardRoutingState... state) {
public static List<ShardRouting> shardsWithState(RoutingNodes routingNodes, String index, ShardRoutingState... states) {
List<ShardRouting> shards = new ArrayList<>();
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(index, state));
}
for (ShardRoutingState s : state) {
if (s == ShardRoutingState.UNASSIGNED) {
for (ShardRoutingState state : states) {
if (state == ShardRoutingState.UNASSIGNED) {
for (ShardRouting unassignedShard : routingNodes.unassigned()) {
if (unassignedShard.index().getName().equals(index)) {
shards.add(unassignedShard);
}
}
break;
} else {
for (RoutingNode routingNode : routingNodes) {
shards.addAll(routingNode.shardsWithState(index, state));
}
}
}
return shards;
Expand All @@ -64,7 +65,6 @@ public static RoutingNode routingNode(String nodeId, DiscoveryNode node, ShardRo
for (ShardRouting shardRouting : shards) {
routingNode.add(shardRouting);
}

return routingNode;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ public void testClusterChangedWatchAliasChanged() throws Exception {
boolean emptyShards = randomBoolean();

if (emptyShards) {
when(routingNode.shardsWithState(eq(newActiveWatchIndex), any())).thenReturn(Collections.emptyList());
when(routingNode.shardsWithState(eq(newActiveWatchIndex), any(ShardRoutingState[].class))).thenReturn(Collections.emptyList());
} else {
Index index = new Index(newActiveWatchIndex, "uuid");
ShardId shardId = new ShardId(index, 0);
Expand Down

0 comments on commit c895331

Please sign in to comment.