diff --git a/server/src/main/java/org/opensearch/action/search/GetAllPitNodesResponse.java b/server/src/main/java/org/opensearch/action/search/GetAllPitNodesResponse.java index 055eb84ab3811..9bb3ab6407696 100644 --- a/server/src/main/java/org/opensearch/action/search/GetAllPitNodesResponse.java +++ b/server/src/main/java/org/opensearch/action/search/GetAllPitNodesResponse.java @@ -41,6 +41,12 @@ public class GetAllPitNodesResponse extends BaseNodesResponse uniquePitIds = new HashSet<>(); + pitInfos.addAll( + getNodes().stream() + .flatMap(p -> p.getPitInfos().stream().filter(t -> uniquePitIds.add(t.getPitId()))) + .collect(Collectors.toList()) + ); } public GetAllPitNodesResponse( diff --git a/server/src/main/java/org/opensearch/action/search/ListPitInfo.java b/server/src/main/java/org/opensearch/action/search/ListPitInfo.java index e120507f4d47a..220b7247517b9 100644 --- a/server/src/main/java/org/opensearch/action/search/ListPitInfo.java +++ b/server/src/main/java/org/opensearch/action/search/ListPitInfo.java @@ -17,6 +17,7 @@ import org.opensearch.core.xcontent.XContentBuilder; import java.io.IOException; +import java.util.Objects; import static org.opensearch.core.xcontent.ConstructingObjectParser.constructorArg; @@ -80,4 +81,17 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + ListPitInfo that = (ListPitInfo) o; + return pitId.equals(that.pitId) && creationTime == that.creationTime && keepAlive == that.keepAlive; + } + + @Override + public int hashCode() { + return Objects.hash(pitId, creationTime, keepAlive); + } + } diff --git a/server/src/test/java/org/opensearch/action/search/GetAllPitNodesResponseTests.java b/server/src/test/java/org/opensearch/action/search/GetAllPitNodesResponseTests.java new file mode 100644 index 0000000000000..882b397575e93 --- /dev/null +++ b/server/src/test/java/org/opensearch/action/search/GetAllPitNodesResponseTests.java @@ -0,0 +1,98 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.action.search; + +import org.opensearch.Version; +import org.opensearch.action.FailedNodeException; +import org.opensearch.cluster.ClusterName; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.core.common.io.stream.NamedWriteableRegistry; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.TransportException; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; + +public class GetAllPitNodesResponseTests extends OpenSearchTestCase { + protected void assertEqualInstances(GetAllPitNodesResponse expected, GetAllPitNodesResponse actual) { + assertNotSame(expected, actual); + Set expectedPitInfos = new HashSet<>(expected.getPitInfos()); + Set actualPitInfos = new HashSet<>(actual.getPitInfos()); + assertEquals(expectedPitInfos, actualPitInfos); + + List expectedResponses = expected.getNodes(); + List actualResponses = actual.getNodes(); + assertEquals(expectedResponses.size(), actualResponses.size()); + for (int i = 0; i < expectedResponses.size(); i++) { + assertEquals(expectedResponses.get(i).getNode(), actualResponses.get(i).getNode()); + Set expectedNodePitInfos = new HashSet<>(expectedResponses.get(i).getPitInfos()); + Set actualNodePitInfos = new HashSet<>(actualResponses.get(i).getPitInfos()); + assertEquals(expectedNodePitInfos, actualNodePitInfos); + } + + List expectedFailures = expected.failures(); + List actualFailures = actual.failures(); + assertEquals(expectedFailures.size(), actualFailures.size()); + for (int i = 0; i < expectedFailures.size(); i++) { + assertEquals(expectedFailures.get(i).nodeId(), actualFailures.get(i).nodeId()); + assertEquals(expectedFailures.get(i).getMessage(), actualFailures.get(i).getMessage()); + assertEquals(expectedFailures.get(i).getCause().getClass(), actualFailures.get(i).getCause().getClass()); + } + } + + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(Collections.emptyList()); + } + + public void testSerialization() throws IOException { + GetAllPitNodesResponse response = createTestItem(); + GetAllPitNodesResponse deserialized = copyWriteable(response, getNamedWriteableRegistry(), GetAllPitNodesResponse::new); + assertEqualInstances(response, deserialized); + } + + private GetAllPitNodesResponse createTestItem() { + int numNodes = randomIntBetween(1, 10); + int numPits = randomInt(10); + List candidatePitInfos = new ArrayList<>(numPits); + for (int i = 0; i < numNodes; i++) { + candidatePitInfos.add(new ListPitInfo(randomAlphaOfLength(10), randomLong(), randomLong())); + } + + List responses = new ArrayList<>(); + List failures = new ArrayList<>(); + for (int i = 0; i < numNodes; i++) { + DiscoveryNode node = new DiscoveryNode( + randomAlphaOfLength(10), + buildNewFakeTransportAddress(), + emptyMap(), + emptySet(), + Version.CURRENT + ); + if (randomBoolean()) { + List nodePitInfos = new ArrayList<>(); + for (int j = 0; j < randomInt(numPits); j++) { + nodePitInfos.add(randomFrom(candidatePitInfos)); + } + responses.add(new GetAllPitNodeResponse(node, nodePitInfos)); + } else { + failures.add( + new FailedNodeException(node.getId(), randomAlphaOfLength(10), new TransportException(randomAlphaOfLength(10))) + ); + } + } + return new GetAllPitNodesResponse(new ClusterName("test"), responses, failures); + } +}