Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add BETWEEN expression in v2 engine #1163

Merged
merged 8 commits into from
Dec 16, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ public static RegexExpression regex(Expression sourceField, Expression pattern,
return new RegexExpression(sourceField, pattern, identifier);
}

public static FunctionExpression between(Expression... expressions) {
return compile(FunctionProperties.None, BuiltinFunctionName.BETWEEN, expressions);
}

public static PatternsExpression patterns(Expression sourceField, Expression pattern,
Expression identifier) {
return new PatternsExpression(sourceField, pattern, identifier);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public enum BuiltinFunctionName {
GTE(FunctionName.of(">=")),
LIKE(FunctionName.of("like")),
NOT_LIKE(FunctionName.of("not like")),
BETWEEN(FunctionName.of("between")),

/**
* Aggregation Function.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ public static void register(BuiltinFunctionRepository repository) {
repository.register(like());
repository.register(notLike());
repository.register(regexp());
repository.register(between());
}

/**
Expand Down Expand Up @@ -253,6 +254,19 @@ private static DefaultFunctionResolver notLike() {
STRING));
}

private static DefaultFunctionResolver between() {
return FunctionDSL.define(BuiltinFunctionName.BETWEEN.getName(),
ExprCoreType.coreTypes().stream().map(
type -> FunctionDSL.impl(
FunctionDSL.nullMissingHandling((v1, v2, v3) ->
ExprBooleanValue.of(v1.compareTo(v2) >= 0 && v1.compareTo(v3) <= 0)),
BOOLEAN,
type,
type,
type))
.collect(Collectors.toList()));
}

private static ExprValue lookupTableFunction(ExprValue arg1, ExprValue arg2,
Table<ExprValue, ExprValue, ExprValue> table) {
if (table.contains(arg1, arg2)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -786,6 +786,29 @@ void testRegexpString(StringPatternPair stringPatternPair) {
.valueOf(valueEnv()).integerValue());
}

@Test
public void test_between() {
Object[][] testData = {
{false, 5, 10, 30},
{true, 10, 10, 30},
{true, 20, 10, 30},
{true, 30, 10, 30},
{false, 45, 10, 30},
{false, "a", "b", "e"},
{true, "c", "b", "e"},
};

for (Object[] data : testData) {
assertEquals(
data[0],
eval(DSL.between(
DSL.literal(fromObjectValue(data[1])),
DSL.literal(fromObjectValue(data[2])),
DSL.literal(fromObjectValue(data[3])))),
String.format("Failed on test data: %s", Arrays.toString(data)));
}
}

/**
* Todo. remove this test cases after script serilization implemented.
*/
Expand Down Expand Up @@ -819,4 +842,8 @@ public void compare_int_long() {
FunctionExpression equal = DSL.equal(DSL.literal(1), DSL.literal(1L));
assertTrue(equal.valueOf(valueEnv()).booleanValue());
}

private boolean eval(Expression expr) {
return expr.valueOf().booleanValue();
}
}
1 change: 1 addition & 0 deletions sql/src/main/antlr/OpenSearchSQLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@ predicate
: expressionAtom #expressionAtomPredicate
| left=predicate comparisonOperator right=predicate #binaryComparisonPredicate
| predicate IS nullNotnull #isNullPredicate
| predicate BETWEEN predicate AND predicate #betweenPredicate
dai-chen marked this conversation as resolved.
Show resolved Hide resolved
| left=predicate NOT? LIKE right=predicate #likePredicate
| left=predicate REGEXP right=predicate #regexpPredicate
| predicate NOT? IN '(' expressions ')' #inPredicate
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@

import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName;
import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.BETWEEN;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.LIKE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.NOT_LIKE;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.POSITION;
import static org.opensearch.sql.expression.function.BuiltinFunctionName.REGEXP;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AlternateMultiMatchFieldContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AlternateMultiMatchQueryContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.BetweenPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.BinaryComparisonPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.BooleanContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CaseFuncAlternativeContext;
Expand All @@ -24,16 +28,25 @@
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DateLiteralContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DistinctCountFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FilterClauseContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FilteredAggregationFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FunctionArgContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.HighlightFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.InPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IsNullPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.LikePredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MathExpressionAtomContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.MultiFieldRelevanceFunctionContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NoFieldRelevanceFunctionContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NotExpressionContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NullLiteralContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.OverClauseContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.PositionFunctionContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.QualifiedNameContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RegexpPredicateContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RegularAggregateFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RelevanceArgContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.RelevanceFieldAndWeightContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScalarFunctionCallContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ScalarWindowFunctionContext;
import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ShowDescribePatternContext;
Expand Down Expand Up @@ -82,7 +95,6 @@
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.expression.function.BuiltinFunctionName;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext;
import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IdentContext;
Expand Down Expand Up @@ -137,7 +149,7 @@ public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ct

@Override
public UnresolvedExpression visitHighlightFunctionCall(
OpenSearchSQLParser.HighlightFunctionCallContext ctx) {
HighlightFunctionCallContext ctx) {
ImmutableMap.Builder<String, Literal> builder = ImmutableMap.builder();
ctx.highlightFunction().highlightArg().forEach(v -> builder.put(
v.highlightArgName().getText().toLowerCase(),
Expand All @@ -151,7 +163,7 @@ public UnresolvedExpression visitHighlightFunctionCall(

@Override
public UnresolvedExpression visitPositionFunction(
OpenSearchSQLParser.PositionFunctionContext ctx) {
PositionFunctionContext ctx) {
return new Function(
POSITION.getName().getFunctionName(),
Arrays.asList(visitFunctionArg(ctx.functionArg(0)),
Expand Down Expand Up @@ -184,7 +196,7 @@ public UnresolvedExpression visitShowDescribePattern(

@Override
public UnresolvedExpression visitFilteredAggregationFunctionCall(
OpenSearchSQLParser.FilteredAggregationFunctionCallContext ctx) {
FilteredAggregationFunctionCallContext ctx) {
AggregateFunction agg = (AggregateFunction) visit(ctx.aggregateFunction());
return agg.condition(visit(ctx.filterClause()));
}
Expand Down Expand Up @@ -241,7 +253,7 @@ public UnresolvedExpression visitCountStarFunctionCall(CountStarFunctionCallCont
}

@Override
public UnresolvedExpression visitFilterClause(OpenSearchSQLParser.FilterClauseContext ctx) {
public UnresolvedExpression visitFilterClause(FilterClauseContext ctx) {
return visit(ctx.expression());
}

Expand All @@ -253,6 +265,14 @@ public UnresolvedExpression visitIsNullPredicate(IsNullPredicateContext ctx) {
Arrays.asList(visit(ctx.predicate())));
}

@Override
public UnresolvedExpression visitBetweenPredicate(BetweenPredicateContext ctx) {
return new Function(BETWEEN.getName().getFunctionName(),
ctx.predicate().stream()
.map(this::visit)
.collect(Collectors.toList()));
}

@Override
public UnresolvedExpression visitLikePredicate(LikePredicateContext ctx) {
return new Function(
Expand All @@ -268,7 +288,7 @@ public UnresolvedExpression visitRegexpPredicate(RegexpPredicateContext ctx) {
}

@Override
public UnresolvedExpression visitInPredicate(OpenSearchSQLParser.InPredicateContext ctx) {
public UnresolvedExpression visitInPredicate(InPredicateContext ctx) {
UnresolvedExpression field = visit(ctx.predicate());
List<UnresolvedExpression> inLists = ctx
.expressions()
Expand Down Expand Up @@ -392,7 +412,7 @@ public UnresolvedExpression visitConvertedDataType(

@Override
public UnresolvedExpression visitNoFieldRelevanceFunction(
OpenSearchSQLParser.NoFieldRelevanceFunctionContext ctx) {
NoFieldRelevanceFunctionContext ctx) {
return new Function(
ctx.noFieldRelevanceFunctionName().getText().toLowerCase(),
noFieldRelevanceArguments(ctx));
Expand All @@ -415,7 +435,7 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction(
if ((funcName.equalsIgnoreCase(BuiltinFunctionName.MULTI_MATCH.toString())
|| funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCH.toString())
|| funcName.equalsIgnoreCase(BuiltinFunctionName.MULTIMATCHQUERY.toString()))
&& ! ctx.getRuleContexts(OpenSearchSQLParser.AlternateMultiMatchQueryContext.class)
&& ! ctx.getRuleContexts(AlternateMultiMatchQueryContext.class)
.isEmpty()) {
return new Function(
ctx.multiFieldRelevanceFunctionName().getText().toLowerCase(),
Expand All @@ -428,7 +448,7 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction(
}

private Function buildFunction(String functionName,
List<OpenSearchSQLParser.FunctionArgContext> arg) {
List<FunctionArgContext> arg) {
return new Function(
functionName,
arg
Expand All @@ -447,7 +467,7 @@ private QualifiedName visitIdentifiers(List<IdentContext> identifiers) {
);
}

private void fillRelevanceArgs(List<OpenSearchSQLParser.RelevanceArgContext> args,
private void fillRelevanceArgs(List<RelevanceArgContext> args,
ImmutableList.Builder<UnresolvedExpression> builder) {
// To support old syntax we must support argument keys as quoted strings.
args.forEach(v -> builder.add(v.argName == null
Expand All @@ -459,7 +479,7 @@ private void fillRelevanceArgs(List<OpenSearchSQLParser.RelevanceArgContext> arg
}

private List<UnresolvedExpression> noFieldRelevanceArguments(
OpenSearchSQLParser.NoFieldRelevanceFunctionContext ctx) {
NoFieldRelevanceFunctionContext ctx) {
// all the arguments are defaulted to string values
// to skip environment resolving and function signature resolving
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
Expand All @@ -470,7 +490,7 @@ private List<UnresolvedExpression> noFieldRelevanceArguments(
}

private List<UnresolvedExpression> singleFieldRelevanceArguments(
OpenSearchSQLParser.SingleFieldRelevanceFunctionContext ctx) {
SingleFieldRelevanceFunctionContext ctx) {
// all the arguments are defaulted to string values
// to skip environment resolving and function signature resolving
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
Expand All @@ -485,12 +505,12 @@ private List<UnresolvedExpression> singleFieldRelevanceArguments(


private List<UnresolvedExpression> multiFieldRelevanceArguments(
OpenSearchSQLParser.MultiFieldRelevanceFunctionContext ctx) {
MultiFieldRelevanceFunctionContext ctx) {
// all the arguments are defaulted to string values
// to skip environment resolving and function signature resolving
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
var fields = new RelevanceFieldList(ctx
.getRuleContexts(OpenSearchSQLParser.RelevanceFieldAndWeightContext.class)
.getRuleContexts(RelevanceFieldAndWeightContext.class)
.stream()
.collect(Collectors.toMap(
f -> StringUtils.unquoteText(f.field.getText()),
Expand All @@ -509,14 +529,14 @@ private List<UnresolvedExpression> multiFieldRelevanceArguments(
* @return : Returns list of all arguments for relevance function.
*/
private List<UnresolvedExpression> alternateMultiMatchArguments(
OpenSearchSQLParser.MultiFieldRelevanceFunctionContext ctx) {
MultiFieldRelevanceFunctionContext ctx) {
// all the arguments are defaulted to string values
// to skip environment resolving and function signature resolving
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
Map<String, Float> fieldAndWeightMap = new HashMap<>();

String[] fieldAndWeights = StringUtils.unquoteText(
ctx.getRuleContexts(OpenSearchSQLParser.AlternateMultiMatchFieldContext.class)
ctx.getRuleContexts(AlternateMultiMatchFieldContext.class)
.stream().findFirst().get().argVal.getText()).split(",");

for (var fieldAndWeight : fieldAndWeights) {
Expand All @@ -527,7 +547,7 @@ private List<UnresolvedExpression> alternateMultiMatchArguments(
builder.add(new UnresolvedArgument("fields",
new RelevanceFieldList(fieldAndWeightMap)));

ctx.getRuleContexts(OpenSearchSQLParser.AlternateMultiMatchQueryContext.class)
ctx.getRuleContexts(AlternateMultiMatchQueryContext.class)
.stream().findFirst().ifPresent(
arg ->
builder.add(new UnresolvedArgument("query",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,14 @@ public void canBuildRegexpExpression() {
);
}

@Test
public void canBuildBetweenExpression() {
assertEquals(
function("between", qualifiedName("age"), intLiteral(10), intLiteral(30)),
buildExprAst("age BETWEEN 10 AND 30")
);
}

@Test
public void canBuildLogicalExpression() {
assertEquals(
Expand Down