From e9029586f18eba8fca93c308080ea44022ba50e9 Mon Sep 17 00:00:00 2001 From: Naveen Tatikonda Date: Thu, 22 Dec 2022 14:42:02 -0600 Subject: [PATCH] Address Review Comments Signed-off-by: Naveen Tatikonda --- .../knn/indices/ModelGraveyard.java | 40 +++++++++++++------ .../knn/indices/ModelGraveyardTests.java | 18 +++++++++ 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java index a6d92b0840..f064fa6e53 100644 --- a/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java +++ b/src/main/java/org/opensearch/knn/indices/ModelGraveyard.java @@ -154,14 +154,15 @@ public static NamedDiff readDiffFrom(StreamInput streamInput) throws IOException */ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws IOException { // Added validation checks to validate all the different possible scenarios - // model_ids:"abcd" - // {} - // {["abcd", "1234"]} - // {"dummy_field_name":} - // {model_ids:"abcd"} - // {model_ids:null} - // {model_ids:[]} - // {model_ids: ["abcd", "1234"]} + // model_ids:"abcd" - Throws exception as the START_OBJECT token is missing + // {} - Returns an empty ModelGraveyard object (BackwardCompatibility) + // {["abcd", "1234"]} - Throws exception as the FIELD_NAME token is missing + // {"dummy_field_name":} - Throws exception as the FIELD_NAME is not matching with model_ids + // {model_ids:"abcd"} - Throws exception as the START_ARRAY token is missing after field name + // {model_ids:null} - Throws exception as the START_ARRAY token is missing + // {model_ids:[]} - Parses and returns an empty ModelGraveyard object as there are no model ids + // {model_ids: ["abcd", "1234"]} - Parses and returns a ModelGraveyard object which contains the model ids "abcd" and "1234" + // {model_ids:[],dummy_field:[]} - Throws exception as we have FIELD_NAME(dummy_field) instead of END_OBJECT token ModelGraveyard modelGraveyard = new ModelGraveyard(); @@ -173,7 +174,8 @@ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws // Validate if the first token is START_OBJECT if (xContentParser.currentToken() != XContentParser.Token.START_OBJECT) { throw new OpenSearchParseException( - "Unable to parse ModelGraveyard. Expecting START_OBJECT but got " + xContentParser.currentToken() + "Unable to parse ModelGraveyard. Expecting START_OBJECT but got {}", + xContentParser.currentToken() ); } @@ -185,29 +187,41 @@ public static ModelGraveyard fromXContent(XContentParser xContentParser) throws // Validate it starts with FIELD_NAME token after START_OBJECT if (xContentParser.currentToken() != XContentParser.Token.FIELD_NAME) { throw new OpenSearchParseException( - "Unable to parse ModelGraveyard. Expecting FIELD_NAME but got " + xContentParser.currentToken() + "Unable to parse ModelGraveyard. Expecting FIELD_NAME but got {}", + xContentParser.currentToken() ); } // Validating that FIELD_NAME matches with "model_ids" if (!MODEL_IDS.equals(xContentParser.currentName())) { throw new OpenSearchParseException( - "Unable to parse ModelGraveyard. Expecting field model_ids but got " + xContentParser.currentName() + "Unable to parse ModelGraveyard. Expecting field {} but got {}", + MODEL_IDS, + xContentParser.currentName() ); } // Validate it starts with START_ARRAY token after FIELD_NAME if (xContentParser.nextToken() != XContentParser.Token.START_ARRAY) { throw new OpenSearchParseException( - "Unable to parse ModelGraveyard. Expecting START_ARRAY but got " + xContentParser.currentToken() + "Unable to parse ModelGraveyard. Expecting START_ARRAY but got {}", + xContentParser.currentToken() ); } - while (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) { + while (xContentParser.nextToken() != XContentParser.Token.END_ARRAY) { if (xContentParser.currentToken() == XContentParser.Token.VALUE_STRING) { modelGraveyard.add(xContentParser.text()); } } + + // Validate if the last token is END_OBJECT + if (xContentParser.nextToken() != XContentParser.Token.END_OBJECT) { + throw new OpenSearchParseException( + "Unable to parse ModelGraveyard. Expecting END_OBJECT but got {}", + xContentParser.currentToken() + ); + } return modelGraveyard; } diff --git a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java index 3c30005f41..3aba56b071 100644 --- a/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java +++ b/src/test/java/org/opensearch/knn/indices/ModelGraveyardTests.java @@ -144,6 +144,24 @@ public void testXContentBuilder6() throws IOException { assertTrue(ex.getMessage().contains("Expecting START_ARRAY but got VALUE_STRING")); } + // Validating {model_ids:["abcd"],model_ids_2:[]} + public void testXContentBuilder7() throws IOException { + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); + xContentBuilder.startObject(); + xContentBuilder.startArray("model_ids"); + xContentBuilder.value("abcd"); + xContentBuilder.endArray(); + xContentBuilder.startArray("model_ids_2"); + xContentBuilder.endArray(); + xContentBuilder.endObject(); + + OpenSearchParseException ex = expectThrows( + OpenSearchParseException.class, + () -> ModelGraveyard.fromXContent(createParser(xContentBuilder)) + ); + assertTrue(ex.getMessage().contains("Expecting END_OBJECT but got FIELD_NAME")); + } + public void testDiffStreams() throws IOException { Set added = new HashSet<>(); Set removed = new HashSet<>();