Skip to content

Commit

Permalink
Added unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Nov 23, 2024
1 parent 2edefec commit 96e0869
Show file tree
Hide file tree
Showing 3 changed files with 451 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/
package org.opensearch.neuralsearch.query;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.apache.lucene.search.DisiPriorityQueue;
Expand All @@ -30,7 +31,7 @@
* corresponds to order of sub-queries in an input Hybrid query.
*/
@Log4j2
public final class HybridQueryScorer extends Scorer {
public class HybridQueryScorer extends Scorer {

// score for each of sub-query in this hybrid query
@Getter
Expand Down Expand Up @@ -100,7 +101,8 @@ public float score() throws IOException {
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
@VisibleForTesting
float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.mapper.TextFieldMapper;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
Expand Down Expand Up @@ -756,6 +757,69 @@ public void testBoost_whenDefaultBoostSet_thenBuildSuccessfully() {
assertNotNull(hybridQueryBuilder);
}

@SneakyThrows
public void testBuild_whenValidParameters_thenCreateQuery() {
String queryText = "test query";
String modelId = "test_model";
String fieldName = "rank_features";

// Create mock context
QueryShardContext context = mock(QueryShardContext.class);
MappedFieldType fieldType = mock(MappedFieldType.class);
when(context.fieldMapper(fieldName)).thenReturn(fieldType);
when(fieldType.typeName()).thenReturn("rank_features");

// Create HybridQueryBuilder instance (no spy since it's final)
NeuralSparseQueryBuilder neuralSparseQueryBuilder = new NeuralSparseQueryBuilder();
neuralSparseQueryBuilder.fieldName(fieldName)
.queryText(queryText)
.modelId(modelId)
.queryTokensSupplier(() -> Map.of("token1", 1.0f, "token2", 0.5f));
HybridQueryBuilder builder = new HybridQueryBuilder().add(neuralSparseQueryBuilder);

// Build query
Query query = builder.toQuery(context);

// Verify
assertNotNull("Query should not be null", query);
assertTrue("Should be HybridQuery", query instanceof HybridQuery);
}

@SneakyThrows
public void testDoEquals_whenSameParameters_thenEqual() {
// Create neural queries
NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder().queryText("test").modelId("test_model");

NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder().queryText("test").modelId("test_model");

// Create neural sparse queries with queryTokensSupplier
NeuralSparseQueryBuilder neuralSparseQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName("test_field")
.queryText("test")
.modelId("test_model")
.queryTokensSupplier(() -> Map.of("token1", 1.0f));

NeuralSparseQueryBuilder neuralSparseQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName("test_field")
.queryText("test")
.modelId("test_model")
.queryTokensSupplier(() -> Map.of("token1", 1.0f));

// Create builders
HybridQueryBuilder builder1 = new HybridQueryBuilder().add(neuralQueryBuilder1).add(neuralSparseQueryBuilder1);

HybridQueryBuilder builder2 = new HybridQueryBuilder().add(neuralQueryBuilder2).add(neuralSparseQueryBuilder2);

// Verify
assertTrue("Builders should be equal", builder1.equals(builder2));
assertEquals("Hash codes should match", builder1.hashCode(), builder2.hashCode());
}

public void testValidate_whenInvalidParameters_thenThrowException() {
// Test null query builder
HybridQueryBuilder builderWithNull = new HybridQueryBuilder();
IllegalArgumentException nullException = assertThrows(IllegalArgumentException.class, () -> builderWithNull.add(null));
assertEquals("inner hybrid query clause cannot be null", nullException.getMessage());
}

public void testVisit() {
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder().add(new NeuralQueryBuilder()).add(new NeuralSparseQueryBuilder());
List<QueryBuilder> visitedQueries = new ArrayList<>();
Expand Down
Loading

0 comments on commit 96e0869

Please sign in to comment.