-
Notifications
You must be signed in to change notification settings - Fork 129
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Adding serialization for filter field in KnnQueryBuilder (#564)
* Adding serialization/deserialization for filter field in Lucene knn query Signed-off-by: Martin Gaievski <[email protected]>
- Loading branch information
1 parent
ce68025
commit 1729748
Showing
8 changed files
with
395 additions
and
16 deletions.
There are no files selected for viewing
86 changes: 86 additions & 0 deletions
86
qa/rolling-upgrade/src/test/java/org/opensearch/knn/bwc/LuceneFilteringIT.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.bwc; | ||
|
||
import org.opensearch.knn.TestUtils; | ||
import org.opensearch.knn.index.query.KNNQueryBuilder; | ||
import org.opensearch.index.query.QueryBuilders; | ||
import org.opensearch.index.query.TermQueryBuilder; | ||
|
||
import org.opensearch.client.Request; | ||
import org.opensearch.client.ResponseException; | ||
import org.opensearch.common.Strings; | ||
import org.opensearch.common.xcontent.ToXContent; | ||
import org.opensearch.common.xcontent.XContentBuilder; | ||
import org.opensearch.common.xcontent.XContentFactory; | ||
|
||
import java.io.IOException; | ||
|
||
import static org.opensearch.knn.TestUtils.NODES_BWC_CLUSTER; | ||
|
||
/** | ||
* Tests scenarios specific to filtering functionality in k-NN in case Lucene is set as an engine | ||
*/ | ||
public class LuceneFilteringIT extends AbstractRollingUpgradeTestCase { | ||
private static final String TEST_FIELD = "test-field"; | ||
private static final int DIMENSIONS = 50; | ||
private static final int K = 10; | ||
private static final int NUM_DOCS = 100; | ||
private static final TermQueryBuilder TERM_QUERY = QueryBuilders.termQuery("_id", "100"); | ||
|
||
public void testLuceneFiltering() throws Exception { | ||
waitForClusterHealthGreen(NODES_BWC_CLUSTER); | ||
float[] queryVector = TestUtils.getQueryVectors(1, DIMENSIONS, NUM_DOCS, true)[0]; | ||
switch (getClusterType()) { | ||
case OLD: | ||
createKnnIndex(testIndex, getKNNDefaultIndexSettings(), createKnnIndexMappingWithLuceneField(TEST_FIELD, DIMENSIONS)); | ||
bulkAddKnnDocs(testIndex, TEST_FIELD, TestUtils.getIndexVectors(NUM_DOCS, DIMENSIONS, true), NUM_DOCS); | ||
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); | ||
break; | ||
case MIXED: | ||
validateSearchKNNIndexFailed(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); | ||
break; | ||
case UPGRADED: | ||
searchKNNIndex(testIndex, new KNNQueryBuilder(TEST_FIELD, queryVector, K, TERM_QUERY), K); | ||
deleteKNNIndex(testIndex); | ||
break; | ||
} | ||
} | ||
|
||
protected String createKnnIndexMappingWithLuceneField(final String fieldName, int dimension) throws IOException { | ||
return Strings.toString( | ||
XContentFactory.jsonBuilder() | ||
.startObject() | ||
.startObject("properties") | ||
.startObject(fieldName) | ||
.field("type", "knn_vector") | ||
.field("dimension", Integer.toString(dimension)) | ||
.startObject("method") | ||
.field("name", "hnsw") | ||
.field("engine", "lucene") | ||
.field("space_type", "l2") | ||
.endObject() | ||
.endObject() | ||
.endObject() | ||
.endObject() | ||
); | ||
} | ||
|
||
private void validateSearchKNNIndexFailed(String index, KNNQueryBuilder knnQueryBuilder, int resultSize) throws IOException { | ||
XContentBuilder builder = XContentFactory.jsonBuilder().startObject().startObject("query"); | ||
knnQueryBuilder.doXContent(builder, ToXContent.EMPTY_PARAMS); | ||
builder.endObject().endObject(); | ||
|
||
Request request = new Request("POST", "/" + index + "/_search"); | ||
|
||
request.addParameter("size", Integer.toString(resultSize)); | ||
request.addParameter("explain", Boolean.toString(true)); | ||
request.addParameter("search_type", "query_then_fetch"); | ||
request.setJsonEntity(Strings.toString(builder)); | ||
|
||
expectThrows(ResponseException.class, () -> client().performRequest(request)); | ||
} | ||
} |
69 changes: 69 additions & 0 deletions
69
src/main/java/org/opensearch/knn/index/KNNClusterContext.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index; | ||
|
||
import com.carrotsearch.hppc.cursors.ObjectCursor; | ||
import lombok.AccessLevel; | ||
import lombok.NoArgsConstructor; | ||
import lombok.extern.log4j.Log4j2; | ||
import org.opensearch.Version; | ||
import org.opensearch.cluster.node.DiscoveryNode; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.collect.ImmutableOpenMap; | ||
|
||
/** | ||
* Class abstracts information related to underlying OpenSearch cluster | ||
*/ | ||
@NoArgsConstructor(access = AccessLevel.PRIVATE) | ||
@Log4j2 | ||
public class KNNClusterContext { | ||
|
||
private ClusterService clusterService; | ||
private static KNNClusterContext instance; | ||
|
||
/** | ||
* Return instance of the cluster context, must be initialized first for proper usage | ||
* @return instance of cluster context | ||
*/ | ||
public static synchronized KNNClusterContext instance() { | ||
if (instance == null) { | ||
instance = new KNNClusterContext(); | ||
} | ||
return instance; | ||
} | ||
|
||
/** | ||
* Initializes instance of cluster context by injecting dependencies | ||
* @param clusterService | ||
*/ | ||
public void initialize(final ClusterService clusterService) { | ||
this.clusterService = clusterService; | ||
} | ||
|
||
/** | ||
* Return minimal OpenSearch version based on all nodes currently discoverable in the cluster | ||
* @return minimal installed OpenSearch version, default to Version.CURRENT which is typically the latest version | ||
*/ | ||
public Version getClusterMinVersion() { | ||
Version minVersion = Version.CURRENT; | ||
ImmutableOpenMap<String, DiscoveryNode> clusterDiscoveryNodes = ImmutableOpenMap.of(); | ||
log.debug("Reading cluster min version"); | ||
try { | ||
clusterDiscoveryNodes = this.clusterService.state().getNodes().getNodes(); | ||
} catch (Exception exception) { | ||
log.error("Cannot get cluster nodes", exception); | ||
} | ||
for (final ObjectCursor<DiscoveryNode> discoveryNodeCursor : clusterDiscoveryNodes.values()) { | ||
final Version nodeVersion = discoveryNodeCursor.value.getVersion(); | ||
if (nodeVersion.before(minVersion)) { | ||
minVersion = nodeVersion; | ||
log.debug("Update cluster min version to {} based on node {}", nodeVersion, discoveryNodeCursor.value.toString()); | ||
} | ||
} | ||
log.debug("Return cluster min version {}", minVersion); | ||
return minVersion; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
53 changes: 53 additions & 0 deletions
53
src/test/java/org/opensearch/knn/index/KNNClusterContextTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index; | ||
|
||
import org.opensearch.Version; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.knn.KNNTestCase; | ||
|
||
import java.util.List; | ||
|
||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.when; | ||
import static org.opensearch.knn.index.KNNClusterTestUtils.mockClusterService; | ||
|
||
public class KNNClusterContextTests extends KNNTestCase { | ||
|
||
public void testSingleNodeCluster() { | ||
ClusterService clusterService = mockClusterService(List.of(Version.V_2_4_0)); | ||
|
||
final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); | ||
knnClusterContext.initialize(clusterService); | ||
|
||
final Version minVersion = knnClusterContext.getClusterMinVersion(); | ||
|
||
assertTrue(Version.V_2_4_0.equals(minVersion)); | ||
} | ||
|
||
public void testMultipleNodesCluster() { | ||
ClusterService clusterService = mockClusterService(List.of(Version.V_3_0_0, Version.V_2_3_0, Version.V_3_0_0)); | ||
|
||
final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); | ||
knnClusterContext.initialize(clusterService); | ||
|
||
final Version minVersion = knnClusterContext.getClusterMinVersion(); | ||
|
||
assertTrue(Version.V_2_3_0.equals(minVersion)); | ||
} | ||
|
||
public void testWhenErrorOnClusterStateDiscover() { | ||
ClusterService clusterService = mock(ClusterService.class); | ||
when(clusterService.state()).thenThrow(new RuntimeException("Cluster state is not ready")); | ||
|
||
final KNNClusterContext knnClusterContext = KNNClusterContext.instance(); | ||
knnClusterContext.initialize(clusterService); | ||
|
||
final Version minVersion = knnClusterContext.getClusterMinVersion(); | ||
|
||
assertTrue(Version.CURRENT.equals(minVersion)); | ||
} | ||
} |
48 changes: 48 additions & 0 deletions
48
src/test/java/org/opensearch/knn/index/KNNClusterTestUtils.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
*/ | ||
|
||
package org.opensearch.knn.index; | ||
|
||
import org.opensearch.Version; | ||
import org.opensearch.cluster.ClusterState; | ||
import org.opensearch.cluster.node.DiscoveryNode; | ||
import org.opensearch.cluster.node.DiscoveryNodes; | ||
import org.opensearch.cluster.service.ClusterService; | ||
import org.opensearch.common.collect.ImmutableOpenMap; | ||
|
||
import java.util.List; | ||
|
||
import static org.mockito.Mockito.mock; | ||
import static org.mockito.Mockito.when; | ||
import static org.opensearch.test.OpenSearchTestCase.randomAlphaOfLength; | ||
|
||
/** | ||
* Collection of util methods required for testing and related to OpenSearch cluster setup and functionality | ||
*/ | ||
public class KNNClusterTestUtils { | ||
|
||
/** | ||
* Create new mock for ClusterService | ||
* @param versions list of versions for cluster nodes | ||
* @return | ||
*/ | ||
public static ClusterService mockClusterService(final List<Version> versions) { | ||
ClusterService clusterService = mock(ClusterService.class); | ||
ClusterState clusterState = mock(ClusterState.class); | ||
when(clusterService.state()).thenReturn(clusterState); | ||
DiscoveryNodes discoveryNodes = mock(DiscoveryNodes.class); | ||
when(clusterState.getNodes()).thenReturn(discoveryNodes); | ||
ImmutableOpenMap.Builder<String, DiscoveryNode> builder = ImmutableOpenMap.builder(); | ||
for (Version version : versions) { | ||
DiscoveryNode clusterNode = mock(DiscoveryNode.class); | ||
when(clusterNode.getVersion()).thenReturn(version); | ||
builder.put(randomAlphaOfLength(10), clusterNode); | ||
} | ||
ImmutableOpenMap<String, DiscoveryNode> mapOfNodes = builder.build(); | ||
when(discoveryNodes.getNodes()).thenReturn(mapOfNodes); | ||
|
||
return clusterService; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.