From 825fb2a38cb964d519533db6d8333240f1b3a014 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 10 Nov 2022 16:17:30 -0600 Subject: [PATCH] Add fix to fromXContent and toXContent in ModelGraveyard (#618) * Add fix to fromXContent and toXContent in ModelGraveyard Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda Signed-off-by: Naveen Tatikonda (cherry picked from commit f92651bfcbcdf5b5f5443a5338ddc4ba15bb0638) --- .../knn/indices/ModelGraveyard.java | 33 +++++++++++++++++-- .../knn/indices/ModelGraveyardTests.java | 22 +++++++++++++ 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java index ff7232bdf..21f13b6a8 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java +++ b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java @@ -21,7 +21,8 @@ import java.util.EnumSet; import java.util.HashSet; import java.util.Set; -import java.util.stream.Collectors; +import java.util.Iterator; + import com.google.common.collect.Sets; /** @@ -80,6 +81,13 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + Iterator model_ids = getModelIds().iterator(); + + builder.startArray("model_ids"); + while (model_ids.hasNext()) { + builder.value(model_ids.next()); + } + builder.endArray(); return builder; } @@ -143,7 +151,28 @@ public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException * @throws IOException */ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws IOException { - return new ModelGraveyard(xContentParser.list().stream().map(Object::toString).collect(Collectors.toSet())); + ModelGraveyard modelGraveyard = new ModelGraveyard(); + + // If it is a fresh parser, move to the first token + if (xContentParser.currentToken() == null) { + xContentParser.nextToken(); + } + + // on a start object move to next token + if (xContentParser.currentToken() == XContentParser.Token.START_OBJECT) { + xContentParser.nextToken(); + } + + if (xContentParser.currentToken() != XContentParser.Token.FIELD_NAME) { + throw new IllegalArgumentException("expected field name but got a " + xContentParser.currentToken()); + } + + while (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) { + if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) { + modelGraveyard.add(xContentParser.text()); + } + } + return modelGraveyard; } /** diff --git a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java index 28b3c7474..694d22ed2 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java @@ -6,6 +6,9 @@ package org.opensearch.knn.indices; import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.xcontent.ToXContent; +import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; @@ -57,6 +60,25 @@ public void testStreams() throws IOException { assertTrue(testModelGraveyardCopy.contains(testModelId)); } + public void testXContentBuilder() throws IOException { + Set modelIds = new HashSet<>(); + String testModelId1 = "test-model-id1"; + String testModelId2 = "test-model-id2"; + modelIds.add(testModelId1); + modelIds.add(testModelId2); + ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds); + + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + XContentBuilder builder = testModelGraveyard.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + ModelGraveyard testModelGraveyard2 = ModelGraveyard.fromXContent(createParser(builder)); + assertEquals(2, testModelGraveyard2.size()); + assertTrue(testModelGraveyard2.contains(testModelId1)); + assertTrue(testModelGraveyard2.contains(testModelId2)); + } + public void testDiffStreams() throws IOException { Set added = new HashSet<>(); Set removed = new HashSet<>();