diff --git a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java index 4667640bf..cd7b657f0 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java @@ -42,7 +42,7 @@ */ public class KNNMethodContext implements ToXContentFragment, Writeable { - private static Logger logger = LogManager.getLogger(KNNMethodContext.class); + private static final Logger logger = LogManager.getLogger(KNNMethodContext.class); private static KNNMethodContext defaultInstance = null; @@ -194,6 +194,11 @@ public static KNNMethodContext parse(Object in) { name = (String) value; } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + if (!(value instanceof Map)) { throw new MapperParsingException("Unable to parse parameters for main method component"); } diff --git a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java index a8c93fe4c..c43c9f77d 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java @@ -21,6 +21,7 @@ import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -37,10 +38,10 @@ */ public class MethodComponentContext implements ToXContentFragment, Writeable { - private static Logger logger = LogManager.getLogger(MethodComponentContext.class); + private static final Logger logger = LogManager.getLogger(MethodComponentContext.class); - private String name; - private Map parameters; + private final String name; + private final Map parameters; /** * Constructor @@ -61,7 +62,15 @@ public MethodComponentContext(String name, Map parameters) { */ public MethodComponentContext(StreamInput in) throws IOException { this.name = in.readString(); - this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader()); + + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // do not read if their are no bytes left is null. Make sure this is in sync with the fellow read method. For + // more information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (in.available() > 0) { + this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader()); + } else { + this.parameters = null; + } } /** @@ -93,6 +102,11 @@ public static MethodComponentContext parse(Object in) { } name = (String) value; } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + if (!(value instanceof Map)) { throw new MapperParsingException("Unable to parse parameters for method component"); } @@ -125,22 +139,30 @@ public static MethodComponentContext parse(Object in) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(NAME, name); - builder.startObject(PARAMETERS); - parameters.forEach((key, value) -> { - try { - if (value instanceof MethodComponentContext) { - builder.startObject(key); - ((MethodComponentContext) value).toXContent(builder, params); - builder.endObject(); - } else { - builder.field(key, value); + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // we just create the null field. If parameters are not null, we created a nested structure. For more + // information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (parameters == null) { + builder.field(PARAMETERS, (String) null); + } else { + builder.startObject(PARAMETERS); + parameters.forEach((key, value) -> { + try { + if (value instanceof MethodComponentContext) { + builder.startObject(key); + ((MethodComponentContext) value).toXContent(builder, params); + builder.endObject(); + } else { + builder.field(key, value); + } + } catch (IOException ioe) { + throw new RuntimeException("Unable to generate xcontent for method component"); } - } catch (IOException ioe) { - throw new RuntimeException("Unable to generate xcontent for method component"); - } - }); - builder.endObject(); + }); + builder.endObject(); + } + return builder; } @@ -176,13 +198,25 @@ public String getName() { * @return parameters */ public Map getParameters() { + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // return an empty map if parameters is null. For more information, refer to + // https://github.com/opensearch-project/k-NN/issues/353. + if (parameters == null) { + return Collections.emptyMap(); + } return parameters; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(this.name); - out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // do not write if parameters is null. Make sure this is in sync with the fellow read method. For more + // information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (this.parameters != null) { + out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + } } // Because the generic StreamOutput writeMap method can only write generic values, we need to create a custom one diff --git a/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java b/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java index 54cb54b70..51fd7c25b 100644 --- a/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java +++ b/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java @@ -11,6 +11,8 @@ import java.util.Set; import java.util.stream.Collectors; import org.apache.http.util.EntityUtils; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; import org.opensearch.knn.index.KNNQueryBuilder; @@ -22,6 +24,10 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.test.rest.OpenSearchRestTestCase; import static org.opensearch.knn.TestUtils.*; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY; import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; @@ -29,6 +35,7 @@ public class KNNBackwardsCompatibilityIT extends KNNRestTestCase { private static final String CLUSTER_NAME = System.getProperty(TEST_CLUSTER_NAME); private final String testIndexName = KNN_BWC_PREFIX + "test-index"; private final String testIndex_Recall = KNN_BWC_PREFIX + "test-index-recall"; + private final String testIndex_NullParams = KNN_BWC_PREFIX + "test-index-null-params"; private final String testFieldName = "test-field"; private final String testField_Recall = "test-field-recall"; private final String testIndex_Recall_Old = KNN_BWC_PREFIX + "test-index-recall-value-old"; @@ -234,6 +241,47 @@ public void testBackwardsCompatibility() throws Exception { } } + public void testNullParametersOnUpgrade() throws Exception { + + // Skip test if version is 1.2 or 1.3 + // systemProperty 'tests.plugin_bwc_version', knn_bwc_version + String bwcVersion = System.getProperty("tests.plugin_bwc_version", null); + if (bwcVersion == null || bwcVersion.startsWith("1.2") || bwcVersion.startsWith("1.3")) { + return; + } + + // Confirm cluster is green before starting + Request waitForGreen = new Request("GET", "/_cluster/health"); + waitForGreen.addParameter("wait_for_nodes", "3"); + waitForGreen.addParameter("wait_for_status", "green"); + client().performRequest(waitForGreen); + + switch (getClusterType()) { + case OLD: + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(testFieldName) + .field("type", "knn_vector") + .field("dimension", String.valueOf(dimensions)) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(PARAMETERS, (String) null) + .endObject() + .endObject() + .endObject() + .endObject() + ); + + createKnnIndex(testIndex_NullParams, getKNNDefaultIndexSettings(), mapping); + break; + case UPGRADED: + deleteKNNIndex(testIndex_NullParams); + break; + } + } + private String getUri(ClusterType clusterType) { switch (clusterType) { case OLD: diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index 902046983..176029ad6 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -239,6 +239,21 @@ public void testParse_invalid() throws IOException { expectThrows(MapperParsingException.class, () -> MethodComponentContext.parse(in7)); } + /** + * Test context method parsing when parameters are set to null + */ + public void testParse_nullParameters() throws IOException { + String methodName = "test-method"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(PARAMETERS, (String) null) + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + assertTrue(knnMethodContext.getMethodComponent().getParameters().isEmpty()); + } + /** * Test context method parsing when input is valid */ diff --git a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java index cf65dd285..00c87503c 100644 --- a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java +++ b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java @@ -47,6 +47,13 @@ public void testStreams() throws IOException { MethodComponentContext copy = new MethodComponentContext(streamOutput.bytes().streamInput()); assertEquals(original, copy); + + // Check that everything works when streams are null + original = new MethodComponentContext(name, null); + streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + copy = new MethodComponentContext(streamOutput.bytes().streamInput()); + assertEquals(original, copy); } /** @@ -105,6 +112,25 @@ public void testGetParameters() throws IOException { MethodComponentContext methodContext = new MethodComponentContext(name, params); assertEquals(paramVal1, methodContext.getParameters().get(paramKey1)); assertEquals(paramVal2, methodContext.getParameters().get(paramKey2)); + + // When parameters are null, an empty map should be returned + methodContext = new MethodComponentContext(name, null); + assertTrue(methodContext.getParameters().isEmpty()); + } + + /** + * Test method component context parsing when parameters are set to null + */ + public void testParse_nullParameters() throws IOException { + String name = "test-name"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, name) + .field(PARAMETERS, (String) null) + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); + assertTrue(methodComponentContext.getParameters().isEmpty()); } /** @@ -200,6 +226,15 @@ public void testToXContent() throws IOException { assertEquals(paramVal1, paramMap.get(paramKey1)); assertEquals(paramVal2, paramMap.get(paramKey2)); + + // Check when parameters are null + xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, name).field(PARAMETERS, (String) null).endObject(); + in = xContentBuilderToMap(xContentBuilder); + methodContext = MethodComponentContext.parse(in); + builder = XContentFactory.jsonBuilder().startObject(); + builder = methodContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); + out = xContentBuilderToMap(builder); + assertNull(out.get(PARAMETERS)); } public void testEquals() { @@ -214,11 +249,15 @@ public void testEquals() { MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2); + MethodComponentContext methodContext4 = new MethodComponentContext(name2, null); + MethodComponentContext methodContext5 = new MethodComponentContext(name2, null); assertEquals(methodContext1, methodContext1); assertEquals(methodContext1, methodContext2); assertNotEquals(methodContext1, methodContext3); assertNotEquals(methodContext1, null); + assertNotEquals(methodContext2, methodContext4); + assertEquals(methodContext4, methodContext5); } public void testHashCode() { @@ -233,9 +272,13 @@ public void testHashCode() { MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2); + MethodComponentContext methodContext4 = new MethodComponentContext(name2, null); + MethodComponentContext methodContext5 = new MethodComponentContext(name2, null); assertEquals(methodContext1.hashCode(), methodContext1.hashCode()); assertEquals(methodContext1.hashCode(), methodContext2.hashCode()); assertNotEquals(methodContext1.hashCode(), methodContext3.hashCode()); + assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode()); + assertEquals(methodContext4.hashCode(), methodContext5.hashCode()); } }