Skip to content

Commit

Permalink
Adding alternate multi_match syntax.
Browse files Browse the repository at this point in the history
Signed-off-by: forestmvey <[email protected]>
  • Loading branch information
forestmvey committed Dec 6, 2022
1 parent 43ceda1 commit 52b8c05
Show file tree
Hide file tree
Showing 10 changed files with 288 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,9 @@ public enum BuiltinFunctionName {
QUERY(FunctionName.of("query")),
MATCH_QUERY(FunctionName.of("match_query")),
MATCHQUERY(FunctionName.of("matchquery")),
MULTI_MATCH(FunctionName.of("multi_match"));
MULTI_MATCH(FunctionName.of("multi_match")),
MULTIMATCH(FunctionName.of("multimatch")),
MULTIMATCHQUERY(FunctionName.of("multimatchquery"));

private final FunctionName name;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@ public class OpenSearchFunctions {
*/
public void register(BuiltinFunctionRepository repository) {
repository.register(match_bool_prefix());
repository.register(multi_match(BuiltinFunctionName.MULTI_MATCH));
repository.register(multi_match(BuiltinFunctionName.MULTIMATCH));
repository.register(multi_match(BuiltinFunctionName.MULTIMATCHQUERY));
repository.register(match(BuiltinFunctionName.MATCH));
repository.register(match(BuiltinFunctionName.MATCHQUERY));
repository.register(match(BuiltinFunctionName.MATCH_QUERY));
repository.register(multi_match());
repository.register(simple_query_string());
repository.register(query());
repository.register(query_string());
Expand Down Expand Up @@ -62,9 +64,8 @@ private static FunctionResolver match_phrase(BuiltinFunctionName matchPhrase) {
return new RelevanceFunctionResolver(funcName, STRING);
}

private static FunctionResolver multi_match() {
FunctionName funcName = BuiltinFunctionName.MULTI_MATCH.getName();
return new RelevanceFunctionResolver(funcName, STRUCT);
private static FunctionResolver multi_match(BuiltinFunctionName multiMatchName) {
return new RelevanceFunctionResolver(multiMatchName.getName(), STRUCT);
}

private static FunctionResolver simple_query_string() {
Expand Down
79 changes: 71 additions & 8 deletions integ-test/src/test/java/org/opensearch/sql/sql/MultiMatchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
package org.opensearch.sql.sql;

import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BEER;
import static org.opensearch.sql.util.MatcherUtils.rows;
import static org.opensearch.sql.util.MatcherUtils.verifyDataRows;

import java.io.IOException;
import org.json.JSONObject;
Expand All @@ -27,38 +29,99 @@ public void init() throws IOException {
*/

@Test
public void test_mandatory_params() throws IOException {
public void test_mandatory_params() {
String query = "SELECT Id FROM " + TEST_INDEX_BEER
+ " WHERE multi_match([\\\"Tags\\\" ^ 1.5, Title, `Body` 4.2], 'taste')";
var result = new JSONObject(executeQuery(query, "jdbc"));
JSONObject result = executeJdbcRequest(query);
assertEquals(16, result.getInt("total"));
}

@Test
public void test_all_params() throws IOException {
public void test_all_params() {
String query = "SELECT Id FROM " + TEST_INDEX_BEER
+ " WHERE multi_match(['Body', Tags], 'taste beer', operator='and', analyzer=english,"
+ "auto_generate_synonyms_phrase_query=true, boost = 0.77, cutoff_frequency=0.33,"
+ "fuzziness = 'AUTO:1,5', fuzzy_transpositions = false, lenient = true, max_expansions = 25,"
+ "minimum_should_match = '2<-25% 9<-3', prefix_length = 7, tie_breaker = 0.3,"
+ "type = most_fields, slop = 2, zero_terms_query = 'ALL');";
var result = new JSONObject(executeQuery(query, "jdbc"));
JSONObject result = executeJdbcRequest(query);
assertEquals(10, result.getInt("total"));
}

@Test
public void verify_wildcard_test() throws IOException {
public void verify_wildcard_test() {
String query1 = "SELECT Id FROM " + TEST_INDEX_BEER
+ " WHERE multi_match(['Tags'], 'taste')";
var result1 = new JSONObject(executeQuery(query1, "jdbc"));
JSONObject result1 = executeJdbcRequest(query1);
String query2 = "SELECT Id FROM " + TEST_INDEX_BEER
+ " WHERE multi_match(['T*'], 'taste')";
var result2 = new JSONObject(executeQuery(query2, "jdbc"));
JSONObject result2 = executeJdbcRequest(query2);
assertNotEquals(result2.getInt("total"), result1.getInt("total"));

String query = "SELECT Id FROM " + TEST_INDEX_BEER
+ " WHERE multi_match(['*Date'], '2014-01-22');";
var result = new JSONObject(executeQuery(query, "jdbc"));
JSONObject result = executeJdbcRequest(query);
assertEquals(10, result.getInt("total"));
}

@Test
public void test_multimatch_alternate_parameter_syntax() {
String query = "SELECT Tags FROM " + TEST_INDEX_BEER
+ " WHERE multimatch('query'='taste', 'fields'='Tags')";
JSONObject result = executeJdbcRequest(query);
assertEquals(8, result.getInt("total"));
}

@Test
public void test_multimatchquery_alternate_parameter_syntax() {
String query = "SELECT Tags FROM " + TEST_INDEX_BEER
+ " WHERE multimatchquery(query='cicerone', fields='Tags')";
JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
verifyDataRows(result, rows("serving cicerone restaurants"),
rows("taste cicerone"));
}

@Test
public void test_quoted_multi_match_alternate_parameter_syntax() {
String query = "SELECT Tags FROM " + TEST_INDEX_BEER
+ " WHERE multi_match('query'='cicerone', 'fields'='Tags')";
JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
verifyDataRows(result, rows("serving cicerone restaurants"),
rows("taste cicerone"));
}

@Test
public void test_multi_match_alternate_parameter_syntax() {
String query = "SELECT Tags FROM " + TEST_INDEX_BEER
+ " WHERE multi_match(query='cicerone', fields='Tags')";
JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
verifyDataRows(result, rows("serving cicerone restaurants"),
rows("taste cicerone"));
}

@Test
public void test_wildcard_multi_match_alternate_parameter_syntax() {
String query = "SELECT Body FROM " + TEST_INDEX_BEER
+ " WHERE multi_match(query='IPA', fields='B*') LIMIT 1";
JSONObject result = executeJdbcRequest(query);
verifyDataRows(result, rows("<p>I know what makes an IPA an IPA, but what are the unique" +
" characteristics of it's common variants? To be specific, the ones I'm interested in are Double IPA" +
" and Black IPA, but general differences between any other styles would be welcome too. </p>\n"));
}

@Test
public void test_all_params_multimatchquery_alternate_parameter_syntax() {
String query = "SELECT Id FROM " + TEST_INDEX_BEER
+ " WHERE multimatchquery(query='cicerone', fields='Tags', 'operator'='or', analyzer=english,"
+ "auto_generate_synonyms_phrase_query=true, boost = 0.77, cutoff_frequency=0.33,"
+ "fuzziness = 'AUTO:1,5', fuzzy_transpositions = false, lenient = true, max_expansions = 25,"
+ "minimum_should_match = '2<-25% 9<-3', prefix_length = 7, tie_breaker = 0.3,"
+ "type = most_fields, slop = 2, zero_terms_query = 'ALL');";

JSONObject result = executeJdbcRequest(query);
assertEquals(2, result.getInt("total"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor<QueryBuilder, Obje
.put(BuiltinFunctionName.MATCH_QUERY.getName(), new MatchQuery())
.put(BuiltinFunctionName.MATCHQUERY.getName(), new MatchQuery())
.put(BuiltinFunctionName.MULTI_MATCH.getName(), new MultiMatchQuery())
.put(BuiltinFunctionName.MULTIMATCH.getName(), new MultiMatchQuery())
.put(BuiltinFunctionName.MULTIMATCHQUERY.getName(), new MultiMatchQuery())
.put(BuiltinFunctionName.SIMPLE_QUERY_STRING.getName(), new SimpleQueryStringQuery())
.put(BuiltinFunctionName.QUERY_STRING.getName(), new QueryStringQuery())
.put(BuiltinFunctionName.MATCH_BOOL_PREFIX.getName(), new MatchBoolPrefixQuery())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@
@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class)
class MultiMatchTest {
private final MultiMatchQuery multiMatchQuery = new MultiMatchQuery();
private final FunctionName multiMatch = FunctionName.of("multi_match");
private final FunctionName multiMatchName = FunctionName.of("multimatch");
private final FunctionName snakeCaseMultiMatchName = FunctionName.of("multi_match");
private final FunctionName multiMatchQueryName = FunctionName.of("multimatchquery");
private static final LiteralExpression fields_value = DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"title", ExprValueUtils.floatValue(1.F),
Expand Down Expand Up @@ -129,27 +131,69 @@ static Stream<List<Expression>> generateValidData() {

@ParameterizedTest
@MethodSource("generateValidData")
public void test_valid_parameters(List<Expression> validArgs) {
public void test_valid_parameters_multiMatch(List<Expression> validArgs) {
Assertions.assertNotNull(multiMatchQuery.build(
new MultiMatchExpression(validArgs)));
}

@ParameterizedTest
@MethodSource("generateValidData")
public void test_valid_parameters_multi_match(List<Expression> validArgs) {
Assertions.assertNotNull(multiMatchQuery.build(
new MultiMatchExpression(validArgs, snakeCaseMultiMatchName)));
}

@ParameterizedTest
@MethodSource("generateValidData")
public void test_valid_parameters_multiMatchQuery(List<Expression> validArgs) {
Assertions.assertNotNull(multiMatchQuery.build(
new MultiMatchExpression(validArgs, multiMatchQueryName)));
}

@Test
public void test_SyntaxCheckException_when_no_arguments() {
public void test_SyntaxCheckException_when_no_arguments_multiMatch() {
List<Expression> arguments = List.of();
assertThrows(SyntaxCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments)));
}

@Test
public void test_SyntaxCheckException_when_one_argument() {
public void test_SyntaxCheckException_when_no_arguments_multi_match() {
List<Expression> arguments = List.of();
assertThrows(SyntaxCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments, multiMatchName)));
}

@Test
public void test_SyntaxCheckException_when_no_arguments_multiMatchQuery() {
List<Expression> arguments = List.of();
assertThrows(SyntaxCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments, multiMatchQueryName)));
}

@Test
public void test_SyntaxCheckException_when_one_argument_multiMatch() {
List<Expression> arguments = List.of(namedArgument("fields", fields_value));
assertThrows(SyntaxCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments)));
}

@Test
public void test_SemanticCheckException_when_invalid_parameter() {
public void test_SyntaxCheckException_when_one_argument_multi_match() {
List<Expression> arguments = List.of(namedArgument("fields", fields_value));
assertThrows(SyntaxCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments, snakeCaseMultiMatchName)));
}

@Test
public void test_SyntaxCheckException_when_one_argument_multiMatchQuery() {
List<Expression> arguments = List.of(namedArgument("fields", fields_value));
assertThrows(SyntaxCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments, multiMatchQueryName)));
}

@Test
public void test_SemanticCheckException_when_invalid_parameter_multiMatch() {
List<Expression> arguments = List.of(
namedArgument("fields", fields_value),
namedArgument("query", query_value),
Expand All @@ -158,15 +202,40 @@ public void test_SemanticCheckException_when_invalid_parameter() {
() -> multiMatchQuery.build(new MultiMatchExpression(arguments)));
}

@Test
public void test_SemanticCheckException_when_invalid_parameter_multi_match() {
List<Expression> arguments = List.of(
namedArgument("fields", fields_value),
namedArgument("query", query_value),
DSL.namedArgument("unsupported", "unsupported_value"));
Assertions.assertThrows(SemanticCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments, snakeCaseMultiMatchName)));
}

@Test
public void test_SemanticCheckException_when_invalid_parameter_multiMatchQuery() {
List<Expression> arguments = List.of(
namedArgument("fields", fields_value),
namedArgument("query", query_value),
DSL.namedArgument("unsupported", "unsupported_value"));
Assertions.assertThrows(SemanticCheckException.class,
() -> multiMatchQuery.build(new MultiMatchExpression(arguments, multiMatchQueryName)));
}

private NamedArgumentExpression namedArgument(String name, LiteralExpression value) {
return DSL.namedArgument(name, value);
}

private class MultiMatchExpression extends FunctionExpression {
public MultiMatchExpression(List<Expression> arguments) {
super(MultiMatchTest.this.multiMatch, arguments);
super(multiMatchName, arguments);
}

public MultiMatchExpression(List<Expression> arguments, FunctionName funcName) {
super(funcName, arguments);
}


@Override
public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
throw new UnsupportedOperationException("Invalid function call, "
Expand Down
1 change: 1 addition & 0 deletions sql/src/main/antlr/OpenSearchSQLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,7 @@ MINUTE_OF_HOUR: 'MINUTE_OF_HOUR';
MONTH_OF_YEAR: 'MONTH_OF_YEAR';
MULTIMATCH: 'MULTIMATCH';
MULTI_MATCH: 'MULTI_MATCH';
MULTIMATCHQUERY: 'MULTIMATCHQUERY';
NESTED: 'NESTED';
PERCENTILES: 'PERCENTILES';
REGEXP_QUERY: 'REGEXP_QUERY';
Expand Down
20 changes: 20 additions & 0 deletions sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,8 @@ multiFieldRelevanceFunction
: multiFieldRelevanceFunctionName LR_BRACKET
LT_SQR_PRTHS field=relevanceFieldAndWeight (COMMA field=relevanceFieldAndWeight)* RT_SQR_PRTHS
COMMA query=relevanceQuery (COMMA relevanceArg)* RR_BRACKET
| multiFieldRelevanceFunctionName LR_BRACKET
alternateMultiMatchQuery COMMA alternateMultiMatchField (COMMA relevanceArg)* RR_BRACKET
;

convertedDataType
Expand Down Expand Up @@ -467,6 +469,8 @@ singleFieldRelevanceFunctionName

multiFieldRelevanceFunctionName
: MULTI_MATCH
| MULTIMATCH
| MULTIMATCHQUERY
| SIMPLE_QUERY_STRING
| QUERY_STRING
;
Expand All @@ -481,6 +485,7 @@ functionArg

relevanceArg
: relevanceArgName EQUAL_SYMBOL relevanceArgValue
| argName=stringLiteral EQUAL_SYMBOL argVal=relevanceArgValue
;

highlightArg
Expand Down Expand Up @@ -530,3 +535,18 @@ highlightArgValue
: stringLiteral
;

alternateMultiMatchArgName
: FIELDS
| QUERY
| stringLiteral
;

alternateMultiMatchQuery
: argName=alternateMultiMatchArgName EQUAL_SYMBOL argVal=relevanceArgValue
;

alternateMultiMatchField
: argName=alternateMultiMatchArgName EQUAL_SYMBOL argVal=relevanceArgValue
| argName=alternateMultiMatchArgName EQUAL_SYMBOL
LT_SQR_PRTHS argVal=relevanceArgValue RT_SQR_PRTHS
;
Loading

0 comments on commit 52b8c05

Please sign in to comment.