Skip to content

Commit

Permalink
Add Backward Compatibility and Validation checks to ModelGraveyard XC…
Browse files Browse the repository at this point in the history
…ontent Bugfix (#692)

* Add Backward Compatibility and Validation checks to ModelGraveyard XContent Bugfix

Signed-off-by: Naveen Tatikonda <[email protected]>

* Address Review Comments

Signed-off-by: Naveen Tatikonda <[email protected]>

Signed-off-by: Naveen Tatikonda <[email protected]>
  • Loading branch information
naveentatikonda authored Dec 23, 2022
1 parent c412c8a commit fcd2d55
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 9 deletions.
70 changes: 62 additions & 8 deletions src/main/java/org/opensearch/knn/indices/ModelGraveyard.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.OpenSearchParseException;
import org.opensearch.Version;
import org.opensearch.cluster.Diff;
import org.opensearch.cluster.NamedDiff;
Expand Down Expand Up @@ -35,6 +36,7 @@
@Log4j2
public class ModelGraveyard implements Metadata.Custom {
public static final String TYPE = "opensearch-knn-blocked-models";
private static final String MODEL_IDS = "model_ids";
private final Set<String> modelIds;

/**
Expand Down Expand Up @@ -83,7 +85,7 @@ public void writeTo(StreamOutput out) throws IOException {
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
Iterator model_ids = getModelIds().iterator();

builder.startArray("model_ids");
builder.startArray(MODEL_IDS);
while (model_ids.hasNext()) {
builder.value(model_ids.next());
}
Expand Down Expand Up @@ -151,26 +153,78 @@ public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException
* @throws IOException
*/
public static ModelGraveyard fromXContent(XContentParser xContentParser) throws IOException {
// Added validation checks to validate all the different possible scenarios
// model_ids:"abcd" - Throws exception as the START_OBJECT token is missing
// {} - Returns an empty ModelGraveyard object (BackwardCompatibility)
// {["abcd", "1234"]} - Throws exception as the FIELD_NAME token is missing
// {"dummy_field_name":} - Throws exception as the FIELD_NAME is not matching with model_ids
// {model_ids:"abcd"} - Throws exception as the START_ARRAY token is missing after field name
// {model_ids:null} - Throws exception as the START_ARRAY token is missing
// {model_ids:[]} - Parses and returns an empty ModelGraveyard object as there are no model ids
// {model_ids: ["abcd", "1234"]} - Parses and returns a ModelGraveyard object which contains the model ids "abcd" and "1234"
// {model_ids:[],dummy_field:[]} - Throws exception as we have FIELD_NAME(dummy_field) instead of END_OBJECT token

ModelGraveyard modelGraveyard = new ModelGraveyard();

// If it is a fresh parser, move to the first token
if (xContentParser.currentToken() == null) {
xContentParser.nextToken();
}

// on a start object move to next token
if (xContentParser.currentToken() == XContentParser.Token.START_OBJECT) {
xContentParser.nextToken();
// Validate if the first token is START_OBJECT
if (xContentParser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new OpenSearchParseException(
"Unable to parse ModelGraveyard. Expecting token start of an object but got {}",
xContentParser.currentToken()
);
}

// Adding Backward Compatibility for the domains that have already parsed the old toXContent logic which has XContent as {}
if (xContentParser.nextToken() == XContentParser.Token.END_OBJECT) {
return modelGraveyard;
}

// Validate it starts with FIELD_NAME token after START_OBJECT
if (xContentParser.currentToken() != XContentParser.Token.FIELD_NAME) {
throw new IllegalArgumentException("expected field name but got a " + xContentParser.currentToken());
throw new OpenSearchParseException(
"Unable to parse ModelGraveyard. Expecting token field name but got {}",
xContentParser.currentToken()
);
}

// Validating that FIELD_NAME matches with "model_ids"
if (!MODEL_IDS.equals(xContentParser.currentName())) {
throw new OpenSearchParseException(
"Unable to parse ModelGraveyard. Expecting field {} but got {}",
MODEL_IDS,
xContentParser.currentName()
);
}

while (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) {
if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) {
modelGraveyard.add(xContentParser.text());
// Validate it starts with START_ARRAY token after FIELD_NAME
if (xContentParser.nextToken() != XContentParser.Token.START_ARRAY) {
throw new OpenSearchParseException(
"Unable to parse ModelGraveyard. Expecting token start of an array but got {}",
xContentParser.currentToken()
);
}

while (xContentParser.nextToken() != XContentParser.Token.END_ARRAY) {
if (xContentParser.currentToken() != XContentParser.Token.VALUE_STRING) {
throw new OpenSearchParseException(
"Unable to parse ModelGraveyard. Expecting token value string but got {}",
xContentParser.currentToken()
);
}
modelGraveyard.add(xContentParser.text());
}

// Validate if the last token is END_OBJECT
if (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) {
throw new OpenSearchParseException(
"Unable to parse ModelGraveyard. Expecting token end of an object but got {}",
xContentParser.currentToken()
);
}
return modelGraveyard;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.indices;

import lombok.SneakyThrows;
import org.opensearch.OpenSearchParseException;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
Expand Down Expand Up @@ -60,7 +62,9 @@ public void testStreams() throws IOException {
assertTrue(testModelGraveyardCopy.contains(testModelId));
}

public void testXContentBuilder() throws IOException {
// Validating {model_ids: ["test-model-id1", "test-model-id2"]}
@SneakyThrows
public void testXContentBuilder_withModelIds_returnsModelGraveyardWithModelIds() {
Set<String> modelIds = new HashSet<>();
String testModelId1 = "test-model-id1";
String testModelId2 = "test-model-id2";
Expand All @@ -79,6 +83,93 @@ public void testXContentBuilder() throws IOException {
assertTrue(testModelGraveyard2.contains(testModelId2));
}

// Validating {model_ids:[]}
@SneakyThrows
public void testXContentBuilder_withoutModelIds_returnsModelGraveyardWithoutModelIds() {
ModelGraveyard testModelGraveyard = new ModelGraveyard();
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
xContentBuilder.startObject();
XContentBuilder builder = testModelGraveyard.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS);
builder.endObject();

ModelGraveyard testModelGraveyard2 = ModelGraveyard.fromXContent(createParser(builder));
assertEquals(0, testModelGraveyard2.size());
}

// Validating {test-model:"abcd"}
@SneakyThrows
public void testXContentBuilder_withWrongFieldName_throwsException() {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
xContentBuilder.startObject();
xContentBuilder.field("test-model");
xContentBuilder.value("abcd");
xContentBuilder.endObject();

OpenSearchParseException ex = expectThrows(
OpenSearchParseException.class,
() -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
);
assertTrue(ex.getMessage().contains("Expecting field model_ids but got test-model"));
}

// Validating {}
@SneakyThrows
public void testXContentBuilder_validateBackwardCompatibility_returnsEmptyModelGraveyardObject() {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
xContentBuilder.startObject();
xContentBuilder.endObject();

ModelGraveyard testModelGraveyard = ModelGraveyard.fromXContent(createParser(xContentBuilder));
assertEquals(0, testModelGraveyard.size());
}

// Validating null
@SneakyThrows
public void testXContentBuilder_withNull_throwsExceptionExpectingStartObject() {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();

OpenSearchParseException ex = expectThrows(
OpenSearchParseException.class,
() -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
);
assertTrue(ex.getMessage().contains("Expecting token start of an object but got null"));
}

// Validating {model_ids:"abcd"}
@SneakyThrows
public void testXContentBuilder_withMissingStartArray_throwsExceptionExpectingStartArray() {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
xContentBuilder.startObject();
xContentBuilder.field("model_ids");
xContentBuilder.value("abcd");
xContentBuilder.endObject();

OpenSearchParseException ex = expectThrows(
OpenSearchParseException.class,
() -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
);
assertTrue(ex.getMessage().contains("Expecting token start of an array but got VALUE_STRING"));
}

// Validating {model_ids:["abcd"],model_ids_2:[]}
@SneakyThrows
public void testXContentBuilder_validateEndObject_throwsExceptionGotFieldName() {
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder();
xContentBuilder.startObject();
xContentBuilder.startArray("model_ids");
xContentBuilder.value("abcd");
xContentBuilder.endArray();
xContentBuilder.startArray("model_ids_2");
xContentBuilder.endArray();
xContentBuilder.endObject();

OpenSearchParseException ex = expectThrows(
OpenSearchParseException.class,
() -> ModelGraveyard.fromXContent(createParser(xContentBuilder))
);
assertTrue(ex.getMessage().contains("Expecting token end of an object but got FIELD_NAME"));
}

public void testDiffStreams() throws IOException {
Set<String> added = new HashSet<>();
Set<String> removed = new HashSet<>();
Expand Down

0 comments on commit fcd2d55

Please sign in to comment.