From b08127c2c4dcd1844072302e838f44897c8f18f5 Mon Sep 17 00:00:00 2001 From: John Mazanec Date: Tue, 5 Apr 2022 17:25:16 -0400 Subject: [PATCH] Allow null value for params in method mappings (#354) Allow user to input null value for parameters field for KNNMethodContext and MethodComponentContext. Adds tests to reproduce Mapping Parsing Error when the parameters take the value null as well as BWC tests. Signed-off-by: John Mazanec --- .../knn/index/KNNMethodContext.java | 7 +- .../knn/index/MethodComponentContext.java | 72 ++++++++++++++----- .../knn/bwc/KNNBackwardsCompatibilityIT.java | 48 +++++++++++++ .../knn/index/KNNMethodContextTests.java | 15 ++++ .../index/MethodComponentContextTests.java | 43 +++++++++++ 5 files changed, 165 insertions(+), 20 deletions(-) diff --git a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java index 4667640bf..cd7b657f0 100644 --- a/src/main/java/org/opensearch/knn/index/KNNMethodContext.java +++ b/src/main/java/org/opensearch/knn/index/KNNMethodContext.java @@ -42,7 +42,7 @@ */ public class KNNMethodContext implements ToXContentFragment, Writeable { - private static Logger logger = LogManager.getLogger(KNNMethodContext.class); + private static final Logger logger = LogManager.getLogger(KNNMethodContext.class); private static KNNMethodContext defaultInstance = null; @@ -194,6 +194,11 @@ public static KNNMethodContext parse(Object in) { name = (String) value; } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + if (!(value instanceof Map)) { throw new MapperParsingException("Unable to parse parameters for main method component"); } diff --git a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java index a8c93fe4c..c43c9f77d 100644 --- a/src/main/java/org/opensearch/knn/index/MethodComponentContext.java +++ b/src/main/java/org/opensearch/knn/index/MethodComponentContext.java @@ -21,6 +21,7 @@ import org.opensearch.index.mapper.MapperParsingException; import java.io.IOException; +import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -37,10 +38,10 @@ */ public class MethodComponentContext implements ToXContentFragment, Writeable { - private static Logger logger = LogManager.getLogger(MethodComponentContext.class); + private static final Logger logger = LogManager.getLogger(MethodComponentContext.class); - private String name; - private Map parameters; + private final String name; + private final Map parameters; /** * Constructor @@ -61,7 +62,15 @@ public MethodComponentContext(String name, Map parameters) { */ public MethodComponentContext(StreamInput in) throws IOException { this.name = in.readString(); - this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader()); + + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // do not read if their are no bytes left is null. Make sure this is in sync with the fellow read method. For + // more information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (in.available() > 0) { + this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader()); + } else { + this.parameters = null; + } } /** @@ -93,6 +102,11 @@ public static MethodComponentContext parse(Object in) { } name = (String) value; } else if (PARAMETERS.equals(key)) { + if (value == null) { + parameters = null; + continue; + } + if (!(value instanceof Map)) { throw new MapperParsingException("Unable to parse parameters for method component"); } @@ -125,22 +139,30 @@ public static MethodComponentContext parse(Object in) { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.field(NAME, name); - builder.startObject(PARAMETERS); - parameters.forEach((key, value) -> { - try { - if (value instanceof MethodComponentContext) { - builder.startObject(key); - ((MethodComponentContext) value).toXContent(builder, params); - builder.endObject(); - } else { - builder.field(key, value); + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // we just create the null field. If parameters are not null, we created a nested structure. For more + // information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (parameters == null) { + builder.field(PARAMETERS, (String) null); + } else { + builder.startObject(PARAMETERS); + parameters.forEach((key, value) -> { + try { + if (value instanceof MethodComponentContext) { + builder.startObject(key); + ((MethodComponentContext) value).toXContent(builder, params); + builder.endObject(); + } else { + builder.field(key, value); + } + } catch (IOException ioe) { + throw new RuntimeException("Unable to generate xcontent for method component"); } - } catch (IOException ioe) { - throw new RuntimeException("Unable to generate xcontent for method component"); - } - }); - builder.endObject(); + }); + builder.endObject(); + } + return builder; } @@ -176,13 +198,25 @@ public String getName() { * @return parameters */ public Map getParameters() { + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // return an empty map if parameters is null. For more information, refer to + // https://github.com/opensearch-project/k-NN/issues/353. + if (parameters == null) { + return Collections.emptyMap(); + } return parameters; } @Override public void writeTo(StreamOutput out) throws IOException { out.writeString(this.name); - out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + + // Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions, + // do not write if parameters is null. Make sure this is in sync with the fellow read method. For more + // information, refer to https://github.com/opensearch-project/k-NN/issues/353. + if (this.parameters != null) { + out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter()); + } } // Because the generic StreamOutput writeMap method can only write generic values, we need to create a custom one diff --git a/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java b/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java index 54cb54b70..51fd7c25b 100644 --- a/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java +++ b/src/test/java/org/opensearch/knn/bwc/KNNBackwardsCompatibilityIT.java @@ -11,6 +11,8 @@ import java.util.Set; import java.util.stream.Collectors; import org.apache.http.util.EntityUtils; +import org.opensearch.common.Strings; +import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.knn.KNNResult; import org.opensearch.knn.TestUtils; import org.opensearch.knn.index.KNNQueryBuilder; @@ -22,6 +24,10 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.test.rest.OpenSearchRestTestCase; import static org.opensearch.knn.TestUtils.*; +import static org.opensearch.knn.common.KNNConstants.KNN_METHOD; +import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW; +import static org.opensearch.knn.common.KNNConstants.NAME; +import static org.opensearch.knn.common.KNNConstants.PARAMETERS; import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY; import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED; @@ -29,6 +35,7 @@ public class KNNBackwardsCompatibilityIT extends KNNRestTestCase { private static final String CLUSTER_NAME = System.getProperty(TEST_CLUSTER_NAME); private final String testIndexName = KNN_BWC_PREFIX + "test-index"; private final String testIndex_Recall = KNN_BWC_PREFIX + "test-index-recall"; + private final String testIndex_NullParams = KNN_BWC_PREFIX + "test-index-null-params"; private final String testFieldName = "test-field"; private final String testField_Recall = "test-field-recall"; private final String testIndex_Recall_Old = KNN_BWC_PREFIX + "test-index-recall-value-old"; @@ -234,6 +241,47 @@ public void testBackwardsCompatibility() throws Exception { } } + public void testNullParametersOnUpgrade() throws Exception { + + // Skip test if version is 1.2 or 1.3 + // systemProperty 'tests.plugin_bwc_version', knn_bwc_version + String bwcVersion = System.getProperty("tests.plugin_bwc_version", null); + if (bwcVersion == null || bwcVersion.startsWith("1.2") || bwcVersion.startsWith("1.3")) { + return; + } + + // Confirm cluster is green before starting + Request waitForGreen = new Request("GET", "/_cluster/health"); + waitForGreen.addParameter("wait_for_nodes", "3"); + waitForGreen.addParameter("wait_for_status", "green"); + client().performRequest(waitForGreen); + + switch (getClusterType()) { + case OLD: + String mapping = Strings.toString( + XContentFactory.jsonBuilder() + .startObject() + .startObject("properties") + .startObject(testFieldName) + .field("type", "knn_vector") + .field("dimension", String.valueOf(dimensions)) + .startObject(KNN_METHOD) + .field(NAME, METHOD_HNSW) + .field(PARAMETERS, (String) null) + .endObject() + .endObject() + .endObject() + .endObject() + ); + + createKnnIndex(testIndex_NullParams, getKNNDefaultIndexSettings(), mapping); + break; + case UPGRADED: + deleteKNNIndex(testIndex_NullParams); + break; + } + } + private String getUri(ClusterType clusterType) { switch (clusterType) { case OLD: diff --git a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java index 902046983..176029ad6 100644 --- a/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java +++ b/src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java @@ -239,6 +239,21 @@ public void testParse_invalid() throws IOException { expectThrows(MapperParsingException.class, () -> MethodComponentContext.parse(in7)); } + /** + * Test context method parsing when parameters are set to null + */ + public void testParse_nullParameters() throws IOException { + String methodName = "test-method"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, methodName) + .field(PARAMETERS, (String) null) + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + KNNMethodContext knnMethodContext = KNNMethodContext.parse(in); + assertTrue(knnMethodContext.getMethodComponent().getParameters().isEmpty()); + } + /** * Test context method parsing when input is valid */ diff --git a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java index cf65dd285..00c87503c 100644 --- a/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java +++ b/src/test/java/org/opensearch/knn/index/MethodComponentContextTests.java @@ -47,6 +47,13 @@ public void testStreams() throws IOException { MethodComponentContext copy = new MethodComponentContext(streamOutput.bytes().streamInput()); assertEquals(original, copy); + + // Check that everything works when streams are null + original = new MethodComponentContext(name, null); + streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + copy = new MethodComponentContext(streamOutput.bytes().streamInput()); + assertEquals(original, copy); } /** @@ -105,6 +112,25 @@ public void testGetParameters() throws IOException { MethodComponentContext methodContext = new MethodComponentContext(name, params); assertEquals(paramVal1, methodContext.getParameters().get(paramKey1)); assertEquals(paramVal2, methodContext.getParameters().get(paramKey2)); + + // When parameters are null, an empty map should be returned + methodContext = new MethodComponentContext(name, null); + assertTrue(methodContext.getParameters().isEmpty()); + } + + /** + * Test method component context parsing when parameters are set to null + */ + public void testParse_nullParameters() throws IOException { + String name = "test-name"; + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .field(NAME, name) + .field(PARAMETERS, (String) null) + .endObject(); + Map in = xContentBuilderToMap(xContentBuilder); + MethodComponentContext methodComponentContext = MethodComponentContext.parse(in); + assertTrue(methodComponentContext.getParameters().isEmpty()); } /** @@ -200,6 +226,15 @@ public void testToXContent() throws IOException { assertEquals(paramVal1, paramMap.get(paramKey1)); assertEquals(paramVal2, paramMap.get(paramKey2)); + + // Check when parameters are null + xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, name).field(PARAMETERS, (String) null).endObject(); + in = xContentBuilderToMap(xContentBuilder); + methodContext = MethodComponentContext.parse(in); + builder = XContentFactory.jsonBuilder().startObject(); + builder = methodContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject(); + out = xContentBuilderToMap(builder); + assertNull(out.get(PARAMETERS)); } public void testEquals() { @@ -214,11 +249,15 @@ public void testEquals() { MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2); + MethodComponentContext methodContext4 = new MethodComponentContext(name2, null); + MethodComponentContext methodContext5 = new MethodComponentContext(name2, null); assertEquals(methodContext1, methodContext1); assertEquals(methodContext1, methodContext2); assertNotEquals(methodContext1, methodContext3); assertNotEquals(methodContext1, null); + assertNotEquals(methodContext2, methodContext4); + assertEquals(methodContext4, methodContext5); } public void testHashCode() { @@ -233,9 +272,13 @@ public void testHashCode() { MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1); MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2); + MethodComponentContext methodContext4 = new MethodComponentContext(name2, null); + MethodComponentContext methodContext5 = new MethodComponentContext(name2, null); assertEquals(methodContext1.hashCode(), methodContext1.hashCode()); assertEquals(methodContext1.hashCode(), methodContext2.hashCode()); assertNotEquals(methodContext1.hashCode(), methodContext3.hashCode()); + assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode()); + assertEquals(methodContext4.hashCode(), methodContext5.hashCode()); } }