diff --git a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java index 0f4281e21..cd7b657f0 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java @@ -195,6 +195,7 @@ public static KNNMethodContext parse(Object in) { name = (String) value; } else if (PARAMETERS.equals(key)) { if (value == null) { + parameters = null; continue; } diff --git a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java index 73ffcae6e..c43c9f77d 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java @@ -21,6 +21,7 @@ import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -61,7 +62,15 @@ public MethodComponentContext(String name, Map parameters) { */ public MethodComponentContext(StreamInput in) throws IOException { this.name = in.readString(); - this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader()); + + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // do not read if their are no bytes left is null. Make sure this is in sync with the fellow read method. For + // more information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (in.available() > 0) { + this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader()); + } else { + this.parameters = null; + } } /** @@ -94,6 +103,7 @@ public static MethodComponentContext parse(Object in) { name = (String) value; } else if (PARAMETERS.equals(key)) { if (value == null) { + parameters = null; continue; } @@ -129,22 +139,30 @@ public static MethodComponentContext parse(Object in) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(NAME, name); - builder.startObject(PARAMETERS); - parameters.forEach((key, value) -> { - try { - if (value instanceof MethodComponentContext) { - builder.startObject(key); - ((MethodComponentContext) value).toXContent(builder, params); - builder.endObject(); - } else { - builder.field(key, value); + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // we just create the null field. If parameters are not null, we created a nested structure. For more + // information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (parameters == null) { + builder.field(PARAMETERS, (String) null); + } else { + builder.startObject(PARAMETERS); + parameters.forEach((key, value) -> { + try { + if (value instanceof MethodComponentContext) { + builder.startObject(key); + ((MethodComponentContext) value).toXContent(builder, params); + builder.endObject(); + } else { + builder.field(key, value); + } + } catch (IOException ioe) { + throw new RuntimeException("Unable to generate xcontent for method component"); } - } catch (IOException ioe) { - throw new RuntimeException("Unable to generate xcontent for method component"); - } - }); - builder.endObject(); + }); + builder.endObject(); + } + return builder; } @@ -180,13 +198,25 @@ public String getName() { * @return parameters */ public Map getParameters() { + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // return an empty map if parameters is null. For more information, refer to + // https://github.com/opensearch-project/k-NN/issues/353. + if (parameters == null) { + return Collections.emptyMap(); + } return parameters; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(this.name); - out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // do not write if parameters is null. Make sure this is in sync with the fellow read method. For more + // information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (this.parameters != null) { + out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + } } // Because the generic StreamOutput writeMap method can only write generic values, we need to create a custom one diff --git a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java index 826bf3c83..00c87503c 100644 --- a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java +++ b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java @@ -47,6 +47,13 @@ public void testStreams() throws IOException { MethodComponentContext copy = new MethodComponentContext(streamOutput.bytes().streamInput()); assertEquals(original, copy); + + // Check that everything works when streams are null + original = new MethodComponentContext(name, null); + streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + copy = new MethodComponentContext(streamOutput.bytes().streamInput()); + assertEquals(original, copy); } /** @@ -105,6 +112,10 @@ public void testGetParameters() throws IOException { MethodComponentContext methodContext = new MethodComponentContext(name, params); assertEquals(paramVal1, methodContext.getParameters().get(paramKey1)); assertEquals(paramVal2, methodContext.getParameters().get(paramKey2)); + + // When parameters are null, an empty map should be returned + methodContext = new MethodComponentContext(name, null); + assertTrue(methodContext.getParameters().isEmpty()); } /** @@ -215,6 +226,15 @@ public void testToXContent() throws IOException { assertEquals(paramVal1, paramMap.get(paramKey1)); assertEquals(paramVal2, paramMap.get(paramKey2)); + + // Check when parameters are null + xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, name).field(PARAMETERS, (String) null).endObject(); + in = xContentBuilderToMap(xContentBuilder); + methodContext = MethodComponentContext.parse(in); + builder = XContentFactory.jsonBuilder().startObject(); + builder = methodContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); + out = xContentBuilderToMap(builder); + assertNull(out.get(PARAMETERS)); } public void testEquals() { @@ -229,11 +249,15 @@ public void testEquals() { MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2); + MethodComponentContext methodContext4 = new MethodComponentContext(name2, null); + MethodComponentContext methodContext5 = new MethodComponentContext(name2, null); assertEquals(methodContext1, methodContext1); assertEquals(methodContext1, methodContext2); assertNotEquals(methodContext1, methodContext3); assertNotEquals(methodContext1, null); + assertNotEquals(methodContext2, methodContext4); + assertEquals(methodContext4, methodContext5); } public void testHashCode() { @@ -248,9 +272,13 @@ public void testHashCode() { MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2); + MethodComponentContext methodContext4 = new MethodComponentContext(name2, null); + MethodComponentContext methodContext5 = new MethodComponentContext(name2, null); assertEquals(methodContext1.hashCode(), methodContext1.hashCode()); assertEquals(methodContext1.hashCode(), methodContext2.hashCode()); assertNotEquals(methodContext1.hashCode(), methodContext3.hashCode()); + assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode()); + assertEquals(methodContext4.hashCode(), methodContext5.hashCode()); } }