Skip to content

Commit

Permalink
Allow null value for params in method mappings (opensearch-project#354)
Browse files Browse the repository at this point in the history
Allow user to input null value for parameters field for KNNMethodContext
and MethodComponentContext. Adds tests to reproduce Mapping Parsing Error when the parameters take
the value null as well as BWC tests.

Signed-off-by: John Mazanec <[email protected]>
  • Loading branch information
jmazanec15 authored Apr 5, 2022
1 parent d9a9f59 commit b08127c
Show file tree
Hide file tree
Showing 5 changed files with 165 additions and 20 deletions.
7 changes: 6 additions & 1 deletion src/main/java/org/opensearch/knn/index/KNNMethodContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
*/
public class KNNMethodContext implements ToXContentFragment, Writeable {

private static Logger logger = LogManager.getLogger(KNNMethodContext.class);
private static final Logger logger = LogManager.getLogger(KNNMethodContext.class);

private static KNNMethodContext defaultInstance = null;

Expand Down Expand Up @@ -194,6 +194,11 @@ public static KNNMethodContext parse(Object in) {

name = (String) value;
} else if (PARAMETERS.equals(key)) {
if (value == null) {
parameters = null;
continue;
}

if (!(value instanceof Map)) {
throw new MapperParsingException("Unable to parse parameters for main method component");
}
Expand Down
72 changes: 53 additions & 19 deletions src/main/java/org/opensearch/knn/index/MethodComponentContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.index.mapper.MapperParsingException;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -37,10 +38,10 @@
*/
public class MethodComponentContext implements ToXContentFragment, Writeable {

private static Logger logger = LogManager.getLogger(MethodComponentContext.class);
private static final Logger logger = LogManager.getLogger(MethodComponentContext.class);

private String name;
private Map<String, Object> parameters;
private final String name;
private final Map<String, Object> parameters;

/**
* Constructor
Expand All @@ -61,7 +62,15 @@ public MethodComponentContext(String name, Map<String, Object> parameters) {
*/
public MethodComponentContext(StreamInput in) throws IOException {
this.name = in.readString();
this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader());

// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// do not read if their are no bytes left is null. Make sure this is in sync with the fellow read method. For
// more information, refer to https://github.com/opensearch-project/k-NN/issues/353.
if (in.available() > 0) {
this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader());
} else {
this.parameters = null;
}
}

/**
Expand Down Expand Up @@ -93,6 +102,11 @@ public static MethodComponentContext parse(Object in) {
}
name = (String) value;
} else if (PARAMETERS.equals(key)) {
if (value == null) {
parameters = null;
continue;
}

if (!(value instanceof Map)) {
throw new MapperParsingException("Unable to parse parameters for method component");
}
Expand Down Expand Up @@ -125,22 +139,30 @@ public static MethodComponentContext parse(Object in) {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(NAME, name);
builder.startObject(PARAMETERS);
parameters.forEach((key, value) -> {
try {
if (value instanceof MethodComponentContext) {
builder.startObject(key);
((MethodComponentContext) value).toXContent(builder, params);
builder.endObject();
} else {
builder.field(key, value);
// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// we just create the null field. If parameters are not null, we created a nested structure. For more
// information, refer to https://github.com/opensearch-project/k-NN/issues/353.
if (parameters == null) {
builder.field(PARAMETERS, (String) null);
} else {
builder.startObject(PARAMETERS);
parameters.forEach((key, value) -> {
try {
if (value instanceof MethodComponentContext) {
builder.startObject(key);
((MethodComponentContext) value).toXContent(builder, params);
builder.endObject();
} else {
builder.field(key, value);
}
} catch (IOException ioe) {
throw new RuntimeException("Unable to generate xcontent for method component");
}
} catch (IOException ioe) {
throw new RuntimeException("Unable to generate xcontent for method component");
}

});
builder.endObject();
});
builder.endObject();
}

return builder;
}

Expand Down Expand Up @@ -176,13 +198,25 @@ public String getName() {
* @return parameters
*/
public Map<String, Object> getParameters() {
// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// return an empty map if parameters is null. For more information, refer to
// https://github.com/opensearch-project/k-NN/issues/353.
if (parameters == null) {
return Collections.emptyMap();
}
return parameters;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.name);
out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter());

// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// do not write if parameters is null. Make sure this is in sync with the fellow read method. For more
// information, refer to https://github.com/opensearch-project/k-NN/issues/353.
if (this.parameters != null) {
out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter());
}
}

// Because the generic StreamOutput writeMap method can only write generic values, we need to create a custom one
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.http.util.EntityUtils;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.KNNQueryBuilder;
Expand All @@ -22,13 +24,18 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import static org.opensearch.knn.TestUtils.*;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY;
import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED;

public class KNNBackwardsCompatibilityIT extends KNNRestTestCase {
private static final String CLUSTER_NAME = System.getProperty(TEST_CLUSTER_NAME);
private final String testIndexName = KNN_BWC_PREFIX + "test-index";
private final String testIndex_Recall = KNN_BWC_PREFIX + "test-index-recall";
private final String testIndex_NullParams = KNN_BWC_PREFIX + "test-index-null-params";
private final String testFieldName = "test-field";
private final String testField_Recall = "test-field-recall";
private final String testIndex_Recall_Old = KNN_BWC_PREFIX + "test-index-recall-value-old";
Expand Down Expand Up @@ -234,6 +241,47 @@ public void testBackwardsCompatibility() throws Exception {
}
}

public void testNullParametersOnUpgrade() throws Exception {

// Skip test if version is 1.2 or 1.3
// systemProperty 'tests.plugin_bwc_version', knn_bwc_version
String bwcVersion = System.getProperty("tests.plugin_bwc_version", null);
if (bwcVersion == null || bwcVersion.startsWith("1.2") || bwcVersion.startsWith("1.3")) {
return;
}

// Confirm cluster is green before starting
Request waitForGreen = new Request("GET", "/_cluster/health");
waitForGreen.addParameter("wait_for_nodes", "3");
waitForGreen.addParameter("wait_for_status", "green");
client().performRequest(waitForGreen);

switch (getClusterType()) {
case OLD:
String mapping = Strings.toString(
XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(testFieldName)
.field("type", "knn_vector")
.field("dimension", String.valueOf(dimensions))
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(PARAMETERS, (String) null)
.endObject()
.endObject()
.endObject()
.endObject()
);

createKnnIndex(testIndex_NullParams, getKNNDefaultIndexSettings(), mapping);
break;
case UPGRADED:
deleteKNNIndex(testIndex_NullParams);
break;
}
}

private String getUri(ClusterType clusterType) {
switch (clusterType) {
case OLD:
Expand Down
15 changes: 15 additions & 0 deletions src/test/java/org/opensearch/knn/index/KNNMethodContextTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ public void testParse_invalid() throws IOException {
expectThrows(MapperParsingException.class, () -> MethodComponentContext.parse(in7));
}

/**
* Test context method parsing when parameters are set to null
*/
public void testParse_nullParameters() throws IOException {
String methodName = "test-method";
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, methodName)
.field(PARAMETERS, (String) null)
.endObject();
Map<String, Object> in = xContentBuilderToMap(xContentBuilder);
KNNMethodContext knnMethodContext = KNNMethodContext.parse(in);
assertTrue(knnMethodContext.getMethodComponent().getParameters().isEmpty());
}

/**
* Test context method parsing when input is valid
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ public void testStreams() throws IOException {
MethodComponentContext copy = new MethodComponentContext(streamOutput.bytes().streamInput());

assertEquals(original, copy);

// Check that everything works when streams are null
original = new MethodComponentContext(name, null);
streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);
copy = new MethodComponentContext(streamOutput.bytes().streamInput());
assertEquals(original, copy);
}

/**
Expand Down Expand Up @@ -105,6 +112,25 @@ public void testGetParameters() throws IOException {
MethodComponentContext methodContext = new MethodComponentContext(name, params);
assertEquals(paramVal1, methodContext.getParameters().get(paramKey1));
assertEquals(paramVal2, methodContext.getParameters().get(paramKey2));

// When parameters are null, an empty map should be returned
methodContext = new MethodComponentContext(name, null);
assertTrue(methodContext.getParameters().isEmpty());
}

/**
* Test method component context parsing when parameters are set to null
*/
public void testParse_nullParameters() throws IOException {
String name = "test-name";
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, name)
.field(PARAMETERS, (String) null)
.endObject();
Map<String, Object> in = xContentBuilderToMap(xContentBuilder);
MethodComponentContext methodComponentContext = MethodComponentContext.parse(in);
assertTrue(methodComponentContext.getParameters().isEmpty());
}

/**
Expand Down Expand Up @@ -200,6 +226,15 @@ public void testToXContent() throws IOException {

assertEquals(paramVal1, paramMap.get(paramKey1));
assertEquals(paramVal2, paramMap.get(paramKey2));

// Check when parameters are null
xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, name).field(PARAMETERS, (String) null).endObject();
in = xContentBuilderToMap(xContentBuilder);
methodContext = MethodComponentContext.parse(in);
builder = XContentFactory.jsonBuilder().startObject();
builder = methodContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject();
out = xContentBuilderToMap(builder);
assertNull(out.get(PARAMETERS));
}

public void testEquals() {
Expand All @@ -214,11 +249,15 @@ public void testEquals() {
MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2);
MethodComponentContext methodContext4 = new MethodComponentContext(name2, null);
MethodComponentContext methodContext5 = new MethodComponentContext(name2, null);

assertEquals(methodContext1, methodContext1);
assertEquals(methodContext1, methodContext2);
assertNotEquals(methodContext1, methodContext3);
assertNotEquals(methodContext1, null);
assertNotEquals(methodContext2, methodContext4);
assertEquals(methodContext4, methodContext5);
}

public void testHashCode() {
Expand All @@ -233,9 +272,13 @@ public void testHashCode() {
MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2);
MethodComponentContext methodContext4 = new MethodComponentContext(name2, null);
MethodComponentContext methodContext5 = new MethodComponentContext(name2, null);

assertEquals(methodContext1.hashCode(), methodContext1.hashCode());
assertEquals(methodContext1.hashCode(), methodContext2.hashCode());
assertNotEquals(methodContext1.hashCode(), methodContext3.hashCode());
assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode());
assertEquals(methodContext4.hashCode(), methodContext5.hashCode());
}
}

0 comments on commit b08127c

Please sign in to comment.