Skip to content

Commit

Permalink
Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder (#1397)
Browse files Browse the repository at this point in the history
Signed-off-by: panguixin <[email protected]>
  • Loading branch information
bugmakerrrrrr authored Jan 23, 2024
1 parent fcbfef1 commit 89fc267
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 10 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
* Properly designate model state for actively training models when nodes crash or leave cluster [#1317](https://github.com/opensearch-project/k-NN/pull/1317)
* Fix script score queries not getting cached [#1367](https://github.com/opensearch-project/k-NN/pull/1367)
* Fix KNNScorer to apply boost [#1403](https://github.com/opensearch-project/k-NN/pull/1403)
* Fix equals and hashCode methods for KNNQuery and KNNQueryBuilder [#1397](https://github.com/opensearch-project/k-NN/pull/1397)
### Infrastructure
* Upgrade gradle to 8.4 [1289](https://github.com/opensearch-project/k-NN/pull/1289)
* Refactor security testing to install from individual components [#1307](https://github.com/opensearch-project/k-NN/pull/1307)
Expand Down
12 changes: 9 additions & 3 deletions src/main/java/org/opensearch/knn/index/query/KNNQuery.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

package org.opensearch.knn.index.query;

import java.util.Arrays;
import java.util.Objects;
import lombok.Getter;
import lombok.Setter;
import org.apache.lucene.search.BooleanClause;
Expand Down Expand Up @@ -127,7 +129,7 @@ public String toString(String field) {

@Override
public int hashCode() {
return field.hashCode() ^ queryVector.hashCode() ^ k;
return Objects.hash(field, Arrays.hashCode(queryVector), k, indexName, filterQuery);
}

@Override
Expand All @@ -136,6 +138,10 @@ public boolean equals(Object other) {
}

private boolean equalsTo(KNNQuery other) {
return this.field.equals(other.getField()) && this.queryVector.equals(other.getQueryVector()) && this.k == other.getK();
return Objects.equals(field, other.field)
&& Arrays.equals(queryVector, other.queryVector)
&& Objects.equals(k, other.k)
&& Objects.equals(indexName, other.indexName)
&& Objects.equals(filterQuery, other.filterQuery);
}
};
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.knn.index.query;

import java.util.Arrays;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.opensearch.core.common.Strings;
Expand Down Expand Up @@ -46,7 +47,7 @@ public class KNNQueryBuilder extends AbstractQueryBuilder<KNNQueryBuilder> {
public static final ParseField K_FIELD = new ParseField("k");
public static final ParseField FILTER_FIELD = new ParseField("filter");
public static final ParseField IGNORE_UNMAPPED_FIELD = new ParseField("ignore_unmapped");
public static int K_MAX = 10000;
public static final int K_MAX = 10000;
/**
* The name for the knn query
*/
Expand Down Expand Up @@ -346,12 +347,16 @@ private ModelMetadata getModelMetadataForField(KNNVectorFieldMapper.KNNVectorFie

@Override
protected boolean doEquals(KNNQueryBuilder other) {
return Objects.equals(fieldName, other.fieldName) && Objects.equals(vector, other.vector) && Objects.equals(k, other.k);
return Objects.equals(fieldName, other.fieldName)
&& Arrays.equals(vector, other.vector)
&& Objects.equals(k, other.k)
&& Objects.equals(filter, other.filter)
&& Objects.equals(ignoreUnmapped, other.ignoreUnmapped);
}

@Override
protected int doHashCode() {
return Objects.hash(fieldName, vector, k);
return Objects.hash(fieldName, Arrays.hashCode(vector), k, filter, ignoreUnmapped);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ public void testEmptyVector() {
expectThrows(IllegalArgumentException.class, () -> new KNNQueryBuilder(FIELD_NAME, queryVector1, K));
}

public void testFromXcontent() throws Exception {
public void testFromXContent() throws Exception {
float[] queryVector = { 1.0f, 2.0f, 3.0f, 4.0f };
KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder(FIELD_NAME, queryVector, K);
XContentBuilder builder = XContentFactory.jsonBuilder();
Expand All @@ -103,10 +103,10 @@ public void testFromXcontent() throws Exception {
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser);
actualBuilder.equals(knnQueryBuilder);
assertEquals(knnQueryBuilder, actualBuilder);
}

public void testFromXcontent_WithFilter() throws Exception {
public void testFromXContent_WithFilter() throws Exception {
final ClusterService clusterService = mockClusterService(Version.CURRENT);

final KNNClusterUtil knnClusterUtil = KNNClusterUtil.instance();
Expand All @@ -125,7 +125,7 @@ public void testFromXcontent_WithFilter() throws Exception {
XContentParser contentParser = createParser(builder);
contentParser.nextToken();
KNNQueryBuilder actualBuilder = KNNQueryBuilder.fromXContent(contentParser);
actualBuilder.equals(knnQueryBuilder);
assertEquals(knnQueryBuilder, actualBuilder);
}

public void testFromXContent_invalidQueryVectorType() throws Exception {
Expand Down

0 comments on commit 89fc267

Please sign in to comment.