Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow null value for params in method mappings #354

Merged
merged 5 commits into from
Apr 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
*/
public class KNNMethodContext implements ToXContentFragment, Writeable {

private static Logger logger = LogManager.getLogger(KNNMethodContext.class);
private static final Logger logger = LogManager.getLogger(KNNMethodContext.class);

private static KNNMethodContext defaultInstance = null;

Expand Down Expand Up @@ -194,6 +194,11 @@ public static KNNMethodContext parse(Object in) {

name = (String) value;
} else if (PARAMETERS.equals(key)) {
if (value == null) {
parameters = null;
continue;
}

if (!(value instanceof Map)) {
throw new MapperParsingException("Unable to parse parameters for main method component");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.index.mapper.MapperParsingException;

import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
Expand All @@ -37,10 +38,10 @@
*/
public class MethodComponentContext implements ToXContentFragment, Writeable {

private static Logger logger = LogManager.getLogger(MethodComponentContext.class);
private static final Logger logger = LogManager.getLogger(MethodComponentContext.class);

private String name;
private Map<String, Object> parameters;
private final String name;
private final Map<String, Object> parameters;

/**
* Constructor
Expand All @@ -61,7 +62,15 @@ public MethodComponentContext(String name, Map<String, Object> parameters) {
*/
public MethodComponentContext(StreamInput in) throws IOException {
this.name = in.readString();
this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader());

// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// do not read if their are no bytes left is null. Make sure this is in sync with the fellow read method. For
// more information, refer to https://github.com/opensearch-project/k-NN/issues/353.
if (in.available() > 0) {
this.parameters = in.readMap(StreamInput::readString, new ParameterMapValueReader());
} else {
this.parameters = null;
}
}

/**
Expand Down Expand Up @@ -93,6 +102,11 @@ public static MethodComponentContext parse(Object in) {
}
name = (String) value;
} else if (PARAMETERS.equals(key)) {
if (value == null) {
parameters = null;
continue;
}

if (!(value instanceof Map)) {
throw new MapperParsingException("Unable to parse parameters for method component");
}
Expand Down Expand Up @@ -125,22 +139,30 @@ public static MethodComponentContext parse(Object in) {
@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
builder.field(NAME, name);
builder.startObject(PARAMETERS);
parameters.forEach((key, value) -> {
try {
if (value instanceof MethodComponentContext) {
builder.startObject(key);
((MethodComponentContext) value).toXContent(builder, params);
builder.endObject();
} else {
builder.field(key, value);
// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// we just create the null field. If parameters are not null, we created a nested structure. For more
// information, refer to https://github.com/opensearch-project/k-NN/issues/353.
if (parameters == null) {
builder.field(PARAMETERS, (String) null);
} else {
builder.startObject(PARAMETERS);
parameters.forEach((key, value) -> {
try {
if (value instanceof MethodComponentContext) {
builder.startObject(key);
((MethodComponentContext) value).toXContent(builder, params);
builder.endObject();
} else {
builder.field(key, value);
}
} catch (IOException ioe) {
throw new RuntimeException("Unable to generate xcontent for method component");
}
} catch (IOException ioe) {
throw new RuntimeException("Unable to generate xcontent for method component");
}

});
builder.endObject();
});
builder.endObject();
}

return builder;
}

Expand Down Expand Up @@ -176,13 +198,25 @@ public String getName() {
* @return parameters
*/
public Map<String, Object> getParameters() {
// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// return an empty map if parameters is null. For more information, refer to
// https://github.com/opensearch-project/k-NN/issues/353.
if (parameters == null) {
return Collections.emptyMap();
}
return parameters;
}

@Override
public void writeTo(StreamOutput out) throws IOException {
out.writeString(this.name);
out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter());

// Due to backwards compatibility issue, parameters could be null. To prevent any null pointer exceptions,
// do not write if parameters is null. Make sure this is in sync with the fellow read method. For more
// information, refer to https://github.com/opensearch-project/k-NN/issues/353.
if (this.parameters != null) {
out.writeMap(this.parameters, StreamOutput::writeString, new ParameterMapValueWriter());
}
}

// Because the generic StreamOutput writeMap method can only write generic values, we need to create a custom one
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.http.util.EntityUtils;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.knn.KNNResult;
import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.KNNQueryBuilder;
Expand All @@ -22,13 +24,18 @@
import org.opensearch.knn.index.SpaceType;
import org.opensearch.test.rest.OpenSearchRestTestCase;
import static org.opensearch.knn.TestUtils.*;
import static org.opensearch.knn.common.KNNConstants.KNN_METHOD;
import static org.opensearch.knn.common.KNNConstants.METHOD_HNSW;
import static org.opensearch.knn.common.KNNConstants.NAME;
import static org.opensearch.knn.common.KNNConstants.PARAMETERS;
import static org.opensearch.knn.index.KNNSettings.KNN_ALGO_PARAM_INDEX_THREAD_QTY;
import static org.opensearch.knn.index.KNNSettings.KNN_MEMORY_CIRCUIT_BREAKER_ENABLED;

public class KNNBackwardsCompatibilityIT extends KNNRestTestCase {
private static final String CLUSTER_NAME = System.getProperty(TEST_CLUSTER_NAME);
private final String testIndexName = KNN_BWC_PREFIX + "test-index";
private final String testIndex_Recall = KNN_BWC_PREFIX + "test-index-recall";
private final String testIndex_NullParams = KNN_BWC_PREFIX + "test-index-null-params";
private final String testFieldName = "test-field";
private final String testField_Recall = "test-field-recall";
private final String testIndex_Recall_Old = KNN_BWC_PREFIX + "test-index-recall-value-old";
Expand Down Expand Up @@ -234,6 +241,47 @@ public void testBackwardsCompatibility() throws Exception {
}
}

public void testNullParametersOnUpgrade() throws Exception {

// Skip test if version is 1.2 or 1.3
// systemProperty 'tests.plugin_bwc_version', knn_bwc_version
String bwcVersion = System.getProperty("tests.plugin_bwc_version", null);
if (bwcVersion == null || bwcVersion.startsWith("1.2") || bwcVersion.startsWith("1.3")) {
return;
}

// Confirm cluster is green before starting
Request waitForGreen = new Request("GET", "/_cluster/health");
waitForGreen.addParameter("wait_for_nodes", "3");
waitForGreen.addParameter("wait_for_status", "green");
client().performRequest(waitForGreen);

switch (getClusterType()) {
case OLD:
String mapping = Strings.toString(
XContentFactory.jsonBuilder()
.startObject()
.startObject("properties")
.startObject(testFieldName)
.field("type", "knn_vector")
.field("dimension", String.valueOf(dimensions))
.startObject(KNN_METHOD)
.field(NAME, METHOD_HNSW)
.field(PARAMETERS, (String) null)
.endObject()
.endObject()
.endObject()
.endObject()
);

createKnnIndex(testIndex_NullParams, getKNNDefaultIndexSettings(), mapping);
break;
case UPGRADED:
deleteKNNIndex(testIndex_NullParams);
break;
}
}

private String getUri(ClusterType clusterType) {
switch (clusterType) {
case OLD:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,21 @@ public void testParse_invalid() throws IOException {
expectThrows(MapperParsingException.class, () -> MethodComponentContext.parse(in7));
}

/**
* Test context method parsing when parameters are set to null
*/
public void testParse_nullParameters() throws IOException {
String methodName = "test-method";
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, methodName)
.field(PARAMETERS, (String) null)
.endObject();
Map<String, Object> in = xContentBuilderToMap(xContentBuilder);
KNNMethodContext knnMethodContext = KNNMethodContext.parse(in);
assertTrue(knnMethodContext.getMethodComponent().getParameters().isEmpty());
}

/**
* Test context method parsing when input is valid
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,13 @@ public void testStreams() throws IOException {
MethodComponentContext copy = new MethodComponentContext(streamOutput.bytes().streamInput());

assertEquals(original, copy);

// Check that everything works when streams are null
original = new MethodComponentContext(name, null);
streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);
copy = new MethodComponentContext(streamOutput.bytes().streamInput());
assertEquals(original, copy);
}

/**
Expand Down Expand Up @@ -105,6 +112,25 @@ public void testGetParameters() throws IOException {
MethodComponentContext methodContext = new MethodComponentContext(name, params);
assertEquals(paramVal1, methodContext.getParameters().get(paramKey1));
assertEquals(paramVal2, methodContext.getParameters().get(paramKey2));

// When parameters are null, an empty map should be returned
methodContext = new MethodComponentContext(name, null);
assertTrue(methodContext.getParameters().isEmpty());
}

/**
* Test method component context parsing when parameters are set to null
*/
public void testParse_nullParameters() throws IOException {
String name = "test-name";
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.field(NAME, name)
.field(PARAMETERS, (String) null)
.endObject();
Map<String, Object> in = xContentBuilderToMap(xContentBuilder);
MethodComponentContext methodComponentContext = MethodComponentContext.parse(in);
assertTrue(methodComponentContext.getParameters().isEmpty());
}

/**
Expand Down Expand Up @@ -200,6 +226,15 @@ public void testToXContent() throws IOException {

assertEquals(paramVal1, paramMap.get(paramKey1));
assertEquals(paramVal2, paramMap.get(paramKey2));

// Check when parameters are null
xContentBuilder = XContentFactory.jsonBuilder().startObject().field(NAME, name).field(PARAMETERS, (String) null).endObject();
in = xContentBuilderToMap(xContentBuilder);
methodContext = MethodComponentContext.parse(in);
builder = XContentFactory.jsonBuilder().startObject();
builder = methodContext.toXContent(builder, ToXContent.EMPTY_PARAMS).endObject();
out = xContentBuilderToMap(builder);
assertNull(out.get(PARAMETERS));
}

public void testEquals() {
Expand All @@ -214,11 +249,15 @@ public void testEquals() {
MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2);
MethodComponentContext methodContext4 = new MethodComponentContext(name2, null);
MethodComponentContext methodContext5 = new MethodComponentContext(name2, null);

assertEquals(methodContext1, methodContext1);
assertEquals(methodContext1, methodContext2);
assertNotEquals(methodContext1, methodContext3);
assertNotEquals(methodContext1, null);
assertNotEquals(methodContext2, methodContext4);
assertEquals(methodContext4, methodContext5);
}

public void testHashCode() {
Expand All @@ -233,9 +272,13 @@ public void testHashCode() {
MethodComponentContext methodContext1 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext2 = new MethodComponentContext(name1, parameters1);
MethodComponentContext methodContext3 = new MethodComponentContext(name2, parameters2);
MethodComponentContext methodContext4 = new MethodComponentContext(name2, null);
MethodComponentContext methodContext5 = new MethodComponentContext(name2, null);

assertEquals(methodContext1.hashCode(), methodContext1.hashCode());
assertEquals(methodContext1.hashCode(), methodContext2.hashCode());
assertNotEquals(methodContext1.hashCode(), methodContext3.hashCode());
assertNotEquals(methodContext1.hashCode(), methodContext4.hashCode());
assertEquals(methodContext4.hashCode(), methodContext5.hashCode());
}
}