Skip to content

Commit

Permalink
Add fix to fromXContent and toXContent in ModelGraveyard (#618) (#625)
Browse files Browse the repository at this point in the history
* Add fix to fromXContent and toXContent in ModelGraveyard

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

Signed-off-by: Naveen Tatikonda <[email protected]>
(cherry picked from commit f92651b)

Co-authored-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
1 parent 1882909 commit 888e696
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 2 deletions.
33 changes: 31 additions & 2 deletions src/main/java/org/opensearch/knn/indices/ModelGraveyard.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
import java.util.EnumSet;
import java.util.HashSet;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.Iterator;

import com.google.common.collect.Sets;

/**
Expand Down Expand Up @@ -80,6 +81,13 @@ public void writeTo(StreamOutput out) throws IOException {

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Iterator model_ids = getModelIds().iterator();

builder.startArray("model_ids");
while (model_ids.hasNext()) {
builder.value(model_ids.next());
}
builder.endArray();
return builder;
}

Expand Down Expand Up @@ -143,7 +151,28 @@ public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException
* @throws IOException
*/
public static ModelGraveyard fromXContent(XContentParser xContentParser) throws IOException {
return new ModelGraveyard(xContentParser.list().stream().map(Object::toString).collect(Collectors.toSet()));
ModelGraveyard modelGraveyard = new ModelGraveyard();

// If it is a fresh parser, move to the first token
if (xContentParser.currentToken() == null) {
xContentParser.nextToken();
}

// on a start object move to next token
if (xContentParser.currentToken() == XContentParser.Token.START_OBJECT) {
xContentParser.nextToken();
}

if (xContentParser.currentToken() != XContentParser.Token.FIELD_NAME) {
throw new IllegalArgumentException("expected field name but got a " + xContentParser.currentToken());
}

while (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) {
if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) {
modelGraveyard.add(xContentParser.text());
}
}
return modelGraveyard;
}

/**
Expand Down
22 changes: 22 additions & 0 deletions src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
package org.opensearch.knn.indices;

import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
Expand Down Expand Up @@ -57,6 +60,25 @@ public void testStreams() throws IOException {
assertTrue(testModelGraveyardCopy.contains(testModelId));
}

public void testXContentBuilder() throws IOException {
Set<String> modelIds = new HashSet<>();
String testModelId1 = "test-model-id1";
String testModelId2 = "test-model-id2";
modelIds.add(testModelId1);
modelIds.add(testModelId2);
ModelGraveyard testModelGraveyard = new ModelGraveyard(modelIds);

XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
xContentBuilder.startObject();
XContentBuilder builder = testModelGraveyard.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS);
builder.endObject();

ModelGraveyard testModelGraveyard2 = ModelGraveyard.fromXContent(createParser(builder));
assertEquals(2, testModelGraveyard2.size());
assertTrue(testModelGraveyard2.contains(testModelId1));
assertTrue(testModelGraveyard2.contains(testModelId2));
}

public void testDiffStreams() throws IOException {
Set<String> added = new HashSet<>();
Set<String> removed = new HashSet<>();
Expand Down

0 comments on commit 888e696

Please sign in to comment.