From fcd2d551c57aab35958516d3198265fe95c906a2 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 22 Dec 2022 18:01:29 -0600 Subject: [PATCH] Add Backward Compatibility and Validation checks to ModelGraveyard XContent Bugfix (#692) * Add Backward Compatibility and Validation checks to ModelGraveyard XContent Bugfix Signed-off-by: Naveen Tatikonda * Address Review Comments Signed-off-by: Naveen Tatikonda Signed-off-by: Naveen Tatikonda --- .../knn/indices/ModelGraveyard.java | 70 ++++++++++++-- .../knn/indices/ModelGraveyardTests.java | 93 ++++++++++++++++++- 2 files changed, 154 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java index 21f13b6a8..5db6d77fd 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java +++ b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java @@ -7,6 +7,7 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.opensearch.OpenSearchParseException; import org.opensearch.Version; import org.opensearch.cluster.Diff; import org.opensearch.cluster.NamedDiff; @@ -35,6 +36,7 @@ @Log4j2 public class ModelGraveyard implements Metadata.Custom { public static final String TYPE = "opensearch-knn-blocked-models"; + private static final String MODEL_IDS = "model_ids"; private final Set modelIds; /** @@ -83,7 +85,7 @@ public void writeTo(StreamOutput out) throws IOException { public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { Iterator model_ids = getModelIds().iterator(); - builder.startArray("model_ids"); + builder.startArray(MODEL_IDS); while (model_ids.hasNext()) { builder.value(model_ids.next()); } @@ -151,6 +153,17 @@ public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException * @throws IOException */ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws IOException { + // Added validation checks to validate all the different possible scenarios + // model_ids:"abcd" - Throws exception as the START_OBJECT token is missing + // {} - Returns an empty ModelGraveyard object (BackwardCompatibility) + // {["abcd", "1234"]} - Throws exception as the FIELD_NAME token is missing + // {"dummy_field_name":} - Throws exception as the FIELD_NAME is not matching with model_ids + // {model_ids:"abcd"} - Throws exception as the START_ARRAY token is missing after field name + // {model_ids:null} - Throws exception as the START_ARRAY token is missing + // {model_ids:[]} - Parses and returns an empty ModelGraveyard object as there are no model ids + // {model_ids: ["abcd", "1234"]} - Parses and returns a ModelGraveyard object which contains the model ids "abcd" and "1234" + // {model_ids:[],dummy_field:[]} - Throws exception as we have FIELD_NAME(dummy_field) instead of END_OBJECT token + ModelGraveyard modelGraveyard = new ModelGraveyard(); // If it is a fresh parser, move to the first token @@ -158,19 +171,60 @@ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws xContentParser.nextToken(); } - // on a start object move to next token - if (xContentParser.currentToken() == XContentParser.Token.START_OBJECT) { - xContentParser.nextToken(); + // Validate if the first token is START_OBJECT + if (xContentParser.currentToken() != XContentParser.Token.START_OBJECT) { + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting token start of an object but got {}", + xContentParser.currentToken() + ); + } + + // Adding Backward Compatibility for the domains that have already parsed the old toXContent logic which has XContent as {} + if (xContentParser.nextToken() == XContentParser.Token.END_OBJECT) { + return modelGraveyard; } + // Validate it starts with FIELD_NAME token after START_OBJECT if (xContentParser.currentToken() != XContentParser.Token.FIELD_NAME) { - throw new IllegalArgumentException("expected field name but got a " + xContentParser.currentToken()); + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting token field name but got {}", + xContentParser.currentToken() + ); + } + + // Validating that FIELD_NAME matches with "model_ids" + if (!MODEL_IDS.equals(xContentParser.currentName())) { + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting field {} but got {}", + MODEL_IDS, + xContentParser.currentName() + ); } - while (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) { - if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) { - modelGraveyard.add(xContentParser.text()); + // Validate it starts with START_ARRAY token after FIELD_NAME + if (xContentParser.nextToken() != XContentParser.Token.START_ARRAY) { + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting token start of an array but got {}", + xContentParser.currentToken() + ); + } + + while (xContentParser.nextToken() != XContentParser.Token.END_ARRAY) { + if (xContentParser.currentToken() != XContentParser.Token.VALUE_STRING) { + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting token value string but got {}", + xContentParser.currentToken() + ); } + modelGraveyard.add(xContentParser.text()); + } + + // Validate if the last token is END_OBJECT + if (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) { + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting token end of an object but got {}", + xContentParser.currentToken() + ); } return modelGraveyard; } diff --git a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java index 694d22ed2..98dcf0cea 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java @@ -5,6 +5,8 @@ package org.opensearch.knn.indices; +import lombok.SneakyThrows; +import org.opensearch.OpenSearchParseException; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.ToXContent; import org.opensearch.common.xcontent.XContentBuilder; @@ -60,7 +62,9 @@ public void testStreams() throws IOException { assertTrue(testModelGraveyardCopy.contains(testModelId)); } - public void testXContentBuilder() throws IOException { + // Validating {model_ids: ["test-model-id1", "test-model-id2"]} + @SneakyThrows + public void testXContentBuilder_withModelIds_returnsModelGraveyardWithModelIds() { Set modelIds = new HashSet<>(); String testModelId1 = "test-model-id1"; String testModelId2 = "test-model-id2"; @@ -79,6 +83,93 @@ public void testXContentBuilder() throws IOException { assertTrue(testModelGraveyard2.contains(testModelId2)); } + // Validating {model_ids:[]} + @SneakyThrows + public void testXContentBuilder_withoutModelIds_returnsModelGraveyardWithoutModelIds() { + ModelGraveyard testModelGraveyard = new ModelGraveyard(); + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + XContentBuilder builder = testModelGraveyard.toXContent(xContentBuilder, ToXContent.EMPTY_PARAMS); + builder.endObject(); + + ModelGraveyard testModelGraveyard2 = ModelGraveyard.fromXContent(createParser(builder)); + assertEquals(0, testModelGraveyard2.size()); + } + + // Validating {test-model:"abcd"} + @SneakyThrows + public void testXContentBuilder_withWrongFieldName_throwsException() { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.field("test-model"); + xContentBuilder.value("abcd"); + xContentBuilder.endObject(); + + OpenSearchParseException ex = expectThrows( + OpenSearchParseException.class, + () -> ModelGraveyard.fromXContent(createParser(xContentBuilder)) + ); + assertTrue(ex.getMessage().contains("Expecting field model_ids but got test-model")); + } + + // Validating {} + @SneakyThrows + public void testXContentBuilder_validateBackwardCompatibility_returnsEmptyModelGraveyardObject() { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.endObject(); + + ModelGraveyard testModelGraveyard = ModelGraveyard.fromXContent(createParser(xContentBuilder)); + assertEquals(0, testModelGraveyard.size()); + } + + // Validating null + @SneakyThrows + public void testXContentBuilder_withNull_throwsExceptionExpectingStartObject() { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + + OpenSearchParseException ex = expectThrows( + OpenSearchParseException.class, + () -> ModelGraveyard.fromXContent(createParser(xContentBuilder)) + ); + assertTrue(ex.getMessage().contains("Expecting token start of an object but got null")); + } + + // Validating {model_ids:"abcd"} + @SneakyThrows + public void testXContentBuilder_withMissingStartArray_throwsExceptionExpectingStartArray() { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.field("model_ids"); + xContentBuilder.value("abcd"); + xContentBuilder.endObject(); + + OpenSearchParseException ex = expectThrows( + OpenSearchParseException.class, + () -> ModelGraveyard.fromXContent(createParser(xContentBuilder)) + ); + assertTrue(ex.getMessage().contains("Expecting token start of an array but got VALUE_STRING")); + } + + // Validating {model_ids:["abcd"],model_ids_2:[]} + @SneakyThrows + public void testXContentBuilder_validateEndObject_throwsExceptionGotFieldName() { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.startArray("model_ids"); + xContentBuilder.value("abcd"); + xContentBuilder.endArray(); + xContentBuilder.startArray("model_ids_2"); + xContentBuilder.endArray(); + xContentBuilder.endObject(); + + OpenSearchParseException ex = expectThrows( + OpenSearchParseException.class, + () -> ModelGraveyard.fromXContent(createParser(xContentBuilder)) + ); + assertTrue(ex.getMessage().contains("Expecting token end of an object but got FIELD_NAME")); + } + public void testDiffStreams() throws IOException { Set added = new HashSet<>(); Set removed = new HashSet<>();