Skip to content

Commit

Permalink
Fix relevance function fields are permissive when fields are missing.
Browse files Browse the repository at this point in the history
Signed-off-by: forestmvey <[email protected]>
  • Loading branch information
forestmvey committed Nov 10, 2022
1 parent 03f30e3 commit 641f751
Show file tree
Hide file tree
Showing 5 changed files with 221 additions and 9 deletions.
12 changes: 4 additions & 8 deletions core/src/main/java/org/opensearch/sql/analysis/Analyzer.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,18 +11,10 @@
import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC;
import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;
import static org.opensearch.sql.utils.MLCommonsConstants.ACTION;
import static org.opensearch.sql.utils.MLCommonsConstants.MODELID;
import static org.opensearch.sql.utils.MLCommonsConstants.PREDICT;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE;
import static org.opensearch.sql.utils.MLCommonsConstants.RCF_TIMESTAMP;
import static org.opensearch.sql.utils.MLCommonsConstants.STATUS;
import static org.opensearch.sql.utils.MLCommonsConstants.TASKID;
import static org.opensearch.sql.utils.MLCommonsConstants.TIME_FIELD;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAIN;
import static org.opensearch.sql.utils.MLCommonsConstants.TRAINANDPREDICT;
import static org.opensearch.sql.utils.SystemIndexUtils.CATALOGS_TABLE_NAME;

import com.google.common.collect.ImmutableList;
Expand Down Expand Up @@ -76,13 +68,15 @@
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.FunctionExpression;
import org.opensearch.sql.expression.LiteralExpression;
import org.opensearch.sql.expression.NamedExpression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.expression.aggregation.Aggregator;
import org.opensearch.sql.expression.aggregation.NamedAggregator;
import org.opensearch.sql.expression.function.BuiltinFunctionRepository;
import org.opensearch.sql.expression.function.FunctionName;
import org.opensearch.sql.expression.function.OpenSearchFunctions;
import org.opensearch.sql.expression.function.TableFunctionImplementation;
import org.opensearch.sql.expression.parse.ParseExpression;
import org.opensearch.sql.planner.logical.LogicalAD;
Expand Down Expand Up @@ -225,6 +219,8 @@ public LogicalPlan visitFilter(Filter node, AnalysisContext context) {
LogicalPlan child = node.getChild().get(0).accept(this, context);
Expression condition = expressionAnalyzer.analyze(node.getCondition(), context);

OpenSearchFunctions.validateFieldList((FunctionExpression)condition, context);

ExpressionReferenceOptimizer optimizer =
new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child);
Expression optimized = optimizer.optimize(condition, context);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,15 @@
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.STRUCT;

import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
import lombok.experimental.UtilityClass;
import org.opensearch.sql.analysis.AnalysisContext;
import org.opensearch.sql.analysis.TypeEnvironment;
import org.opensearch.sql.analysis.symbol.Namespace;
import org.opensearch.sql.analysis.symbol.Symbol;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.data.type.ExprType;
Expand All @@ -22,6 +27,66 @@

@UtilityClass
public class OpenSearchFunctions {
private final List<String> singleFieldFunctionNames = ImmutableList.of(
BuiltinFunctionName.MATCH.name(),
BuiltinFunctionName.MATCH_BOOL_PREFIX.name(),
BuiltinFunctionName.MATCHPHRASE.name(),
BuiltinFunctionName.MATCH_PHRASE_PREFIX.name()
);

private final List<String> multiFieldFunctionNames = ImmutableList.of(
BuiltinFunctionName.MULTI_MATCH.name(),
BuiltinFunctionName.SIMPLE_QUERY_STRING.name(),
BuiltinFunctionName.QUERY_STRING.name()
);

/**
* Check if supplied function name is valid SingleFieldRelevanceFunction.
* @param funcName : Name of function
* @return : True if function is single-field function
*/
public static boolean isSingleFieldFunction(String funcName) {
return singleFieldFunctionNames.contains(funcName.toUpperCase());
}

/**
* Check if supplied function name is valid MultiFieldRelevanceFunction.
* @param funcName : Name of function
* @return : True if function is multi-field function
*/
public static boolean isMultiFieldFunction(String funcName) {
return multiFieldFunctionNames.contains(funcName.toUpperCase());
}

/**
* Verify if function queries fields available in type environment.
* @param node : Function used in query.
* @param context : Context of fields querying.
*/
public static void validateFieldList(FunctionExpression node, AnalysisContext context) {
String funcName = node.getFunctionName().toString();

TypeEnvironment typeEnv = context.peek();
if (isSingleFieldFunction(funcName)) {
node.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg ->
((arg.getArgName().equals("field")
&& !arg.getValue().toString().contains("*"))
)).findFirst().ifPresent(arg ->
typeEnv.resolve(new Symbol(Namespace.FIELD_NAME,
StringUtils.unquoteText(arg.getValue().toString()))
)
);
} else if (isMultiFieldFunction(funcName)) {
node.getArguments().stream().map(NamedArgumentExpression.class::cast).filter(arg ->
arg.getArgName().equals("fields")
).findFirst().ifPresent(fields ->
fields.getValue().valueOf(null).tupleValue()
.entrySet().stream().filter(k -> !(k.getKey().contains("*"))
).forEach(key -> typeEnv.resolve(new Symbol(Namespace.FIELD_NAME, key.getKey())))
);
}
}

/**
* Add functions specific to OpenSearch to repository.
*/
Expand Down
129 changes: 129 additions & 0 deletions core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.tuple.ImmutablePair;
Expand All @@ -71,11 +72,14 @@
import org.opensearch.sql.ast.expression.HighlightFunction;
import org.opensearch.sql.ast.expression.Literal;
import org.opensearch.sql.ast.expression.ParseMethod;
import org.opensearch.sql.ast.expression.RelevanceFieldList;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.ast.tree.AD;
import org.opensearch.sql.ast.tree.Kmeans;
import org.opensearch.sql.ast.tree.ML;
import org.opensearch.sql.ast.tree.RareTopN.CommandType;
import org.opensearch.sql.data.model.ExprTupleValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
import org.opensearch.sql.expression.DSL;
Expand Down Expand Up @@ -265,6 +269,131 @@ public void analyze_filter_aggregation_relation() {
aggregate("MIN", qualifiedName("integer_value")), intLiteral(10))));
}

@Test
public void single_field_relevance_query_semantic_exception() {
SemanticCheckException exception =
assertThrows(
SemanticCheckException.class,
() ->
analyze(
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("match",
AstDSL.unresolvedArg("field", stringLiteral("missing_value")),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));
assertEquals(
"can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env",
exception.getMessage());
}

@Test
public void single_field_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.match(
dsl.namedArgument("field", DSL.literal("string_value")),
dsl.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("match",
AstDSL.unresolvedArg("field", stringLiteral("string_value")),
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void single_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.match(
dsl.namedArgument("field", DSL.literal("wildcard_field*")),
dsl.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("match",
AstDSL.unresolvedArg("field", stringLiteral("wildcard_field*")),
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void multi_field_relevance_query_semantic_exception() {
SemanticCheckException exception =
assertThrows(
SemanticCheckException.class,
() ->
analyze(
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"missing_value1", 1.F, "missing_value2", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));
assertEquals(
"can't resolve Symbol(namespace=FIELD_NAME, name=missing_value1) in type env",
exception.getMessage());
}

@Test
public void multi_field_relevance_query_mixed_fields_semantic_exception() {
SemanticCheckException exception =
assertThrows(
SemanticCheckException.class,
() ->
analyze(
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"string_value", 1.F, "missing_value", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value"))))));
assertEquals(
"can't resolve Symbol(namespace=FIELD_NAME, name=missing_value) in type env",
exception.getMessage());
}

@Test
public void multi_field_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.query_string(
dsl.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"string_value", ExprValueUtils.floatValue(1.F),
"integer_value", ExprValueUtils.floatValue(.3F))
))
)),
dsl.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"string_value", 1.F, "integer_value", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void multi_field_wildcard_relevance_query() {
assertAnalyzeEqual(
LogicalPlanDSL.filter(
LogicalPlanDSL.relation("schema", table),
dsl.query_string(
dsl.namedArgument("fields", DSL.literal(
new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of(
"wildcard_field1*", ExprValueUtils.floatValue(1.F),
"wildcard_field2*", ExprValueUtils.floatValue(.3F))
))
)),
dsl.namedArgument("query", DSL.literal("query_value")))),
AstDSL.filter(
AstDSL.relation("schema"),
AstDSL.function("query_string",
AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of(
"wildcard_field1*", 1.F, "wildcard_field2*", .3F))),
AstDSL.unresolvedArg("query", stringLiteral("query_value")))));
}

@Test
public void rename_relation() {
assertAnalyzeEqual(
Expand Down
11 changes: 11 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/MatchIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.utils.StringUtils;

public class MatchIT extends SQLIntegTestCase {
@Override
Expand All @@ -35,4 +36,14 @@ public void match_in_having() throws IOException {
verifySchema(result, schema("lastname", "text"));
verifyDataRows(result, rows("Bates"));
}

@Test
public void missing_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE match(invalid, 'Bates')", TEST_INDEX_ACCOUNT);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env") &&
exception.getMessage().contains("SemanticCheckException"));
}
}
11 changes: 11 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/QueryStringIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import org.json.JSONObject;
import org.junit.Test;
import org.opensearch.sql.legacy.SQLIntegTestCase;
import org.opensearch.sql.legacy.utils.StringUtils;

public class QueryStringIT extends SQLIntegTestCase {
@Override
Expand Down Expand Up @@ -65,4 +66,14 @@ public void wildcard_test() throws IOException {
JSONObject result3 = executeJdbcRequest(query3);
assertEquals(10, result3.getInt("total"));
}

@Test
public void missing_field_test() {
String query = StringUtils.format("SELECT * FROM %s WHERE query_string([invalid], 'beer')", TEST_INDEX_BEER);
final RuntimeException exception =
expectThrows(RuntimeException.class, () -> executeJdbcRequest(query));
assertTrue(exception.getMessage()
.contains("can't resolve Symbol(namespace=FIELD_NAME, name=invalid) in type env") &&
exception.getMessage().contains("SemanticCheckException"));
}
}

0 comments on commit 641f751

Please sign in to comment.