Skip to content

Commit

Permalink
Fix conflict and optimize some function.
Browse files Browse the repository at this point in the history
Signed-off-by: conggguan <[email protected]>
  • Loading branch information
conggguan committed Apr 22, 2024
2 parents 17d7916 + f5a47a6 commit 86a11dd
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@

/**
* Implementation of Query interface for type NeuralSparseQuery when TwoPhaseNeuralSparse Enabled.
* Initialized, it currentQuery include all tokenQuery. After
* Initialized, it currentQuery include all tokenQuery. After call setCurrentQueryToHighScoreTokenQuery,
* it will perform highScoreTokenQuery.
*/
@AllArgsConstructor
@Getter
Expand Down Expand Up @@ -57,8 +58,6 @@ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
public void visit(QueryVisitor queryVisitor) {
QueryVisitor v = queryVisitor.getSubVisitor(BooleanClause.Occur.SHOULD, this);
currentQuery.visit(v);
highScoreTokenQuery.visit(v);
lowScoreTokenQuery.visit(v);
}

@Override
Expand Down Expand Up @@ -86,6 +85,10 @@ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float bo
return currentQuery.createWeight(searcher, scoreMode, boost);
}

/**
* Before call this function, the currentQuery of this object is allTokenQuery.
* After call this function, the currentQuery of this object change to highScoreTokenQuery.
*/
public void setCurrentQueryToHighScoreTokenQuery() {
this.currentQuery = highScoreTokenQuery;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* "query_tokens": {
* "token_a": float,
* "token_b": float,
* ...
* ...,
* }
* }
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.apache.lucene.search.BooleanQuery;
import org.apache.commons.lang.StringUtils;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
Expand Down Expand Up @@ -522,7 +521,7 @@ public void testToXContentWithFullField() {
}

Map<String, Object> secondInnerMap = (Map<String, Object>) secondInner;
assertEquals(5, secondInnerMap.size());
assertEquals(6, secondInnerMap.size());

assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
Expand Down Expand Up @@ -557,8 +556,11 @@ public void testToXContentWithFullField() {
@SneakyThrows
public void testToXContentWithNullableField() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT);
.modelId(null)
.queryText(null)
.maxTokenScore(null)
.queryTokensSupplier(null)
.neuralSparseTwoPhaseParameters(null);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand All @@ -580,10 +582,7 @@ public void testToXContentWithNullableField() {
}

Map<String, Object> secondInnerMap = (Map<String, Object>) secondInner;
assertEquals(3, secondInnerMap.size());

assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
assertEquals(1, secondInnerMap.size());
}

public void testStreams_whenCurrentVersion_thenSuccess() {
Expand Down Expand Up @@ -979,6 +978,7 @@ public void testTokenDividedByScores_whenDefaultSettings() {
assertEquals(lowScoreTokenQuery.clauses().size(), 5);
}

@SneakyThrows
public void testDoToQuery_successfulDoToQuery() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.maxTokenScore(MAX_TOKEN_SCORE)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ public void testBasicQueryUsingQueryTokens() {
}

/**
<<<<<<< HEAD
* Tests basic query with boost:
* {
* "query": {
Expand Down Expand Up @@ -183,6 +184,8 @@ public void testBasicQueryUsingQueryTokens_whenTwoPhaseEnabled() {
}

/**
=======
>>>>>>> origin
* Tests rescore query:
* {
* "query" : {
Expand Down Expand Up @@ -286,11 +289,8 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() {
* "model_id": "dcsdcasd"
* }
* },
* "neural_sparse": {
* "field2": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd"
* }
* "match": {
* "field2": "Hello world a b",
* }
* ]
* }
Expand Down

0 comments on commit 86a11dd

Please sign in to comment.