Skip to content

Commit

Permalink
Adding serialization for filter field in KnnQueryBuilder (opensearch-…
Browse files Browse the repository at this point in the history
…project#564)

* Adding serialization/deserialization for filter field in Lucene knn query

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 20, 2022
1 parent 8c6d4fd commit cdb709a
Show file tree
Hide file tree
Showing 7 changed files with 393 additions and 16 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.bwc;

import org.opensearch.knn.TestUtils;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;

import org.opensearch.client.Request;
import org.opensearch.client.ResponseException;
import org.opensearch.common.Strings;
import org.opensearch.common.xcontent.ToXContent;
import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;

import java.io.IOException;

import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER;

/**
 * Rolling-upgrade test for k-NN queries that carry a filter clause, run against an index whose
 * vector field is mapped with the Lucene engine.
 *
 * <p>The same filtered query is sent at each stage of the upgrade. While the cluster is in the
 * OLD or MIXED state the search request is expected to fail with a {@link ResponseException};
 * in the UPGRADED state the query is issued through {@code searchKNNIndex}.
 */
public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase {
    private static final String TEST_FIELD = "test-field";
    private static final int DIMENSIONS = 50;
    private static final int K = 10;
    private static final int NUM_DOCS = 100;
    private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100");

    /**
     * Drives the scenario for whichever upgrade stage the cluster is currently in.
     *
     * @throws Exception if any of the cluster interactions fails unexpectedly
     */
    public void testLuceneFiltering() throws Exception {
        waitForClusterHealthGreen(NODES_BWC_CLUSTER);
        final float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0];
        switch (getClusterType()) {
            case OLD:
                // The index only needs to be created and populated once, before any node is upgraded.
                createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMappingWithLuceneField(TEST_FIELD, DIMENSIONS));
                bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS);
                validateSearchKNNIndexFailed(testIndex, filteredQuery(queryVector), K);
                break;
            case MIXED:
                validateSearchKNNIndexFailed(testIndex, filteredQuery(queryVector), K);
                break;
            case UPGRADED:
                searchKNNIndex(testIndex, filteredQuery(queryVector), K);
                deleteKNNIndex(testIndex);
                break;
        }
    }

    /**
     * Builds the JSON mapping for a single {@code knn_vector} field that uses the {@code hnsw} method
     * with the {@code lucene} engine and {@code l2} space type.
     *
     * @param fieldName name of the vector field
     * @param dimension dimension of the vector field, written to the mapping as a string
     * @return the mapping serialized as a JSON string
     * @throws IOException if the JSON builder fails
     */
    protected String createKnnIndexMappingWithLuceneField(final String fieldName, int dimension) throws IOException {
        final XContentBuilder mapping = XContentFactory.jsonBuilder();
        mapping.startObject();
        mapping.startObject("properties");
        mapping.startObject(fieldName);
        mapping.field("type", "knn_vector");
        mapping.field("dimension", Integer.toString(dimension));
        mapping.startObject("method");
        mapping.field("name", "hnsw");
        mapping.field("engine", "lucene");
        mapping.field("space_type", "l2");
        mapping.endObject();
        mapping.endObject();
        mapping.endObject();
        mapping.endObject();
        return Strings.toString(mapping);
    }

    /**
     * Creates a fresh k-NN query for {@link #TEST_FIELD} with {@link #K} neighbours and the shared term filter.
     *
     * @param queryVector vector to search for
     * @return a new query builder instance
     */
    private KNNQueryBuilder filteredQuery(final float[] queryVector) {
        return new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY);
    }

    /**
     * Sends the given k-NN query to the {@code _search} endpoint of the index and asserts that the
     * request is rejected with a {@link ResponseException}.
     *
     * @param index name of the index to search
     * @param knnQueryBuilder query to serialize into the request body
     * @param resultSize value of the {@code size} request parameter
     * @throws IOException if the request body cannot be built
     */
    private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException {
        final XContentBuilder searchBody = XContentFactory.jsonBuilder();
        searchBody.startObject();
        searchBody.startObject("query");
        knnQueryBuilder.doXContent(searchBody, ToXContent.EMPTY_PARAMS);
        searchBody.endObject();
        searchBody.endObject();

        final Request searchRequest = new Request("POST", "/" + index + "/_search");
        searchRequest.addParameter("size", Integer.toString(resultSize));
        searchRequest.addParameter("explain", Boolean.toString(true));
        searchRequest.addParameter("search_type", "query_then_fetch");
        searchRequest.setJsonEntity(Strings.toString(searchBody));

        expectThrows(ResponseException.class, () -> client().performRequest(searchRequest));
    }
}
69 changes: 69 additions & 0 deletions src/main/java/org/opensearch/knn/index/KNNClusterContext.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import com.carrotsearch.hppc.cursors.ObjectCursor;
import lombok.AccessLevel;
import lombok.NoArgsConstructor;
import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.ImmutableOpenMap;

/**
* Class abstracts information related to underlying OpenSearch cluster
*/
@NoArgsConstructor(access = AccessLevel.PRIVATE)
@Log4j2
public class KNNClusterContext {

private ClusterService clusterService;
private static KNNClusterContext instance;

/**
* Return instance of the cluster context, must be initialized first for proper usage
* @return instance of cluster context
*/
public static synchronized KNNClusterContext instance() {
if (instance == null) {
instance = new KNNClusterContext();
}
return instance;
}

/**
* Initializes instance of cluster context by injecting dependencies
* @param clusterService
*/
public void initialize(final ClusterService clusterService) {
this.clusterService = clusterService;
}

/**
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version
*/
public Version getClusterMinVersion() {
Version minVersion = Version.CURRENT;
ImmutableOpenMap<String, DiscoveryNode> clusterDiscoveryNodes = ImmutableOpenMap.of();
log.debug("Reading cluster min version");
try {
clusterDiscoveryNodes = this.clusterService.state().getNodes().getNodes();
} catch (Exception exception) {
log.error("Cannot get cluster nodes", exception);
}
for (final ObjectCursor<DiscoveryNode> discoveryNodeCursor : clusterDiscoveryNodes.values()) {
final Version nodeVersion = discoveryNodeCursor.value.getVersion();
if (nodeVersion.before(minVersion)) {
minVersion = nodeVersion;
log.debug("Update cluster min version to {} based on node {}", nodeVersion, discoveryNodeCursor.value.toString());
}
}
log.debug("Return cluster min version {}", minVersion);
return minVersion;
}
}
33 changes: 31 additions & 2 deletions src/main/java/org/opensearch/knn/index/query/KNNQueryBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
package org.opensearch.knn.index.query;

import lombok.extern.log4j.Log4j2;
import org.opensearch.Version;
import org.opensearch.index.mapper.NumberFieldMapper;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.knn.index.KNNClusterContext;
import org.opensearch.knn.index.KNNMethodContext;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
import org.opensearch.knn.index.util.KNNEngine;
Expand Down Expand Up @@ -52,6 +54,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
private final float[] vector;
private int k = 0;
private QueryBuilder filter;
private static final Version MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER = Version.V_3_0_0;

/**
* Constructs a new knn query
Expand Down Expand Up @@ -109,8 +112,11 @@ public KNNQueryBuilder(StreamInput in) throws IOException {
fieldName = in.readString();
vector = in.readFloatArray();
k = in.readInt();
if (isClusterOnOrAfterMinRequiredVersion()) {
filter = in.readOptionalNamedWriteable(QueryBuilder.class);
}
} catch (IOException ex) {
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder: " + ex);
throw new RuntimeException("[KNN] Unable to create KNNQueryBuilder", ex);
}
}

Expand Down Expand Up @@ -152,7 +158,23 @@ public static KNNQueryBuilder fromXContent(XContentParser parser) throws IOExcep
String tokenName = parser.currentName();
if (FILTER_FIELD.getPreferredName().equals(tokenName)) {
log.debug(String.format("Start parsing filter for field [%s]", fieldName));
filter = parseInnerQueryBuilder(parser);
if (isClusterOnOrAfterMinRequiredVersion()) {
filter = parseInnerQueryBuilder(parser);
} else {
log.debug(
String.format(
"This version of k-NN doesn't support [filter] field, minimal required version is [%s]",
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
)
);
throw new IllegalArgumentException(
String.format(
"%s field is supported from version %s",
FILTER_FIELD.getPreferredName(),
MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER
)
);
}
} else {
throw new ParsingException(parser.getTokenLocation(), "[" + NAME + "] unknown token [" + token + "]");
}
Expand Down Expand Up @@ -181,6 +203,9 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeFloatArray(vector);
out.writeInt(k);
if (isClusterOnOrAfterMinRequiredVersion()) {
out.writeOptionalNamedWriteable(filter);
}
}

/**
Expand Down Expand Up @@ -294,4 +319,8 @@ protected int doHashCode() {
public String getWriteableName() {
return NAME;
}

/**
 * Checks the minimal version reported by {@link KNNClusterContext} against
 * {@code MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER}.
 *
 * @return true if the cluster min version is on or after the minimal supported version
 */
private static boolean isClusterOnOrAfterMinRequiredVersion() {
    final Version clusterMinVersion = KNNClusterContext.instance().getClusterMinVersion();
    return clusterMinVersion.onOrAfter(MINIMAL_SUPPORTED_VERSION_FOR_LUCENE_HNSW_FILTER);
}
}
2 changes: 2 additions & 0 deletions src/main/java/org/opensearch/knn/plugin/KNNPlugin.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.opensearch.index.codec.CodecServiceFactory;
import org.opensearch.index.engine.EngineFactory;
import org.opensearch.knn.index.KNNCircuitBreaker;
import org.opensearch.knn.index.KNNClusterContext;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.knn.index.KNNSettings;
import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
Expand Down Expand Up @@ -179,6 +180,7 @@ public Collection<Object> createComponents(
NativeMemoryLoadStrategy.TrainingLoadStrategy.initialize(vectorReader);

KNNSettings.state().initialize(client, clusterService);
KNNClusterContext.instance().initialize(clusterService);
ModelDao.OpenSearchKNNModelDao.initialize(client, clusterService, environment.settings());
ModelCache.initialize(ModelDao.OpenSearchKNNModelDao.getInstance(), clusterService);
TrainingJobRunner.initialize(threadPool, ModelDao.OpenSearchKNNModelDao.getInstance());
Expand Down
53 changes: 53 additions & 0 deletions src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.opensearch.Version;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.knn.KNNTestCase;

import java.util.List;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService;

/**
 * Unit tests for {@link KNNClusterContext#getClusterMinVersion()}.
 */
public class KNNClusterContextTests extends KNNTestCase {

    /** A cluster with a single node reports that node's version. */
    public void testSingleNodeCluster() {
        ClusterService clusterService = mockClusterService(List.of(Version.V_2_4_0));

        assertEquals(Version.V_2_4_0, readMinVersion(clusterService));
    }

    /** A cluster with nodes on different versions reports the lowest one. */
    public void testMultipleNodesCluster() {
        ClusterService clusterService = mockClusterService(List.of(Version.V_3_0_0, Version.V_2_3_0, Version.V_3_0_0));

        assertEquals(Version.V_2_3_0, readMinVersion(clusterService));
    }

    /** When reading the cluster state throws, the context falls back to Version.CURRENT. */
    public void testWhenErrorOnClusterStateDiscover() {
        ClusterService clusterService = mock(ClusterService.class);
        when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready"));

        assertEquals(Version.CURRENT, readMinVersion(clusterService));
    }

    /**
     * Injects the given cluster service into the shared context and reads the min version back.
     *
     * @param clusterService cluster service (usually a mock) to initialize the context with
     * @return the version returned by {@link KNNClusterContext#getClusterMinVersion()}
     */
    private static Version readMinVersion(final ClusterService clusterService) {
        final KNNClusterContext knnClusterContext = KNNClusterContext.instance();
        knnClusterContext.initialize(clusterService);
        return knnClusterContext.getClusterMinVersion();
    }
}
48 changes: 48 additions & 0 deletions src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.knn.index;

import org.opensearch.Version;
import org.opensearch.cluster.ClusterState;
import org.opensearch.cluster.node.DiscoveryNode;
import org.opensearch.cluster.node.DiscoveryNodes;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.collect.ImmutableOpenMap;

import java.util.List;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength;

/**
 * Test helpers that build mocked OpenSearch cluster components.
 */
public class KNNClusterTestUtils {

    /**
     * Creates a Mockito mock of {@link ClusterService} whose cluster state exposes one mocked
     * {@link DiscoveryNode} per requested version.
     *
     * @param versions list of versions for cluster nodes
     * @return mocked cluster service for which {@code state().getNodes().getNodes()} returns the mocked nodes
     */
    public static ClusterService mockClusterService(final List<Version> versions) {
        final ImmutableOpenMap<String, DiscoveryNode> nodesById = mockNodes(versions);

        final DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class);
        when(discoveryNodes.getNodes()).thenReturn(nodesById);

        final ClusterState clusterState = mock(ClusterState.class);
        when(clusterState.getNodes()).thenReturn(discoveryNodes);

        final ClusterService clusterService = mock(ClusterService.class);
        when(clusterService.state()).thenReturn(clusterState);
        return clusterService;
    }

    /**
     * Builds a map of mocked nodes, one per version, each stored under a random 10-character key.
     */
    private static ImmutableOpenMap<String, DiscoveryNode> mockNodes(final List<Version> versions) {
        final ImmutableOpenMap.Builder<String, DiscoveryNode> nodes = ImmutableOpenMap.builder();
        for (final Version nodeVersion : versions) {
            final DiscoveryNode node = mock(DiscoveryNode.class);
            when(node.getVersion()).thenReturn(nodeVersion);
            nodes.put(randomAlphaOfLength(10), node);
        }
        return nodes.build();
    }
}
Loading

0 comments on commit cdb709a

Please sign in to comment.