From 66143065c57763d788a218e2439a48144687055b Mon Sep 17 00:00:00 2001 From: Max Ksyunz Date: Tue, 8 Nov 2022 12:33:36 -0800 Subject: [PATCH] Use query execution start time as the value of now-like functions. (#149) - Add FunctionProperties interface to capture query metadata and provide to function implementations. - Update FunctionBuilder to take FunctionProperties in addition to function arguments when evaluating a SQL function. - Implement now-like functions using FunctionProperties. - Add FunctionDSL.nullMissingHandlingWithProperties to allow for consistent null and missing value handling across all functions. - Remove constant value caching from ExpressionAnalyzer. It is no longer necessary in ExpressionAnalyzer -- the same behavior is now implemented with FunctionProperties. ### Unit Tests - Adjust getQueryStartClock_differs_from_instantNow unit test. On Windows, getQueryStartClock_differs_from_instantNow fails because Instant.now() in the test returns the same value as Instant.now() called for FunctionProperties construction. - Add unit tests for FunctionDSL. - Use Spring to instantiate dsl in OpenSearchTestBase. Signed-off-by: MaxKsyunz --- .../sql/analysis/AnalysisContext.java | 12 - .../sql/analysis/ExpressionAnalyzer.java | 15 -- .../sql/ast/AbstractNodeVisitor.java | 5 - .../org/opensearch/sql/ast/dsl/AstDSL.java | 5 - .../sql/ast/expression/ConstantFunction.java | 28 --- .../org/opensearch/sql/expression/DSL.java | 2 +- .../aggregation/AggregatorFunction.java | 58 ++--- .../expression/config/ExpressionConfig.java | 15 +- .../expression/datetime/DateTimeFunction.java | 173 ++++++------- .../function/BuiltinFunctionRepository.java | 25 +- .../expression/function/FunctionBuilder.java | 2 +- .../sql/expression/function/FunctionDSL.java | 116 ++++++--- .../function/FunctionProperties.java | 45 ++++ .../function/RelevanceFunctionResolver.java | 4 +- .../expression/system/SystemFunctions.java | 3 +- .../expression/window/WindowFunctions.java | 3 +- .../sql/analysis/AnalyzerTestBase.java | 3 +- .../sql/analysis/ExpressionAnalyzerTest.java | 15 -- .../expression/ExpressionNodeVisitorTest.java | 10 +- .../sql/expression/ExpressionTestBase.java | 4 + .../config/ExpressionConfigTest.java | 31 +++ .../expression/datetime/DateTimeTestBase.java | 98 ++++---- .../sql/expression/datetime/MakeDateTest.java | 2 - .../sql/expression/datetime/MakeTimeTest.java | 4 - .../datetime/NowLikeFunctionTest.java | 192 ++++++++++----- .../datetime/UnixTimeStampTest.java | 6 +- .../datetime/UnixTwoWayConversionTest.java | 28 +-- .../BuiltinFunctionRepositoryTest.java | 41 ++-- .../function/FunctionDSLDefineTest.java | 71 ++++++ .../function/FunctionDSLTestBase.java | 62 +++++ .../function/FunctionDSLimplNoArgTest.java | 31 +++ .../function/FunctionDSLimplOneArgTest.java | 32 +++ .../function/FunctionDSLimplTestBase.java | 89 +++++++ .../function/FunctionDSLimplThreeArgTest.java | 32 +++ .../function/FunctionDSLimplTwoArgTest.java | 32 +++ ...nctionDSLimplWithPropertiesNoArgsTest.java | 29 +++ ...nctionDSLimplWithPropertiesOneArgTest.java | 34 +++ .../FunctionDSLnullMissingHandlingTest.java | 109 +++++++++ .../function/FunctionPropertiesTest.java | 69 ++++++ .../convert/TypeCastOperatorTest.java | 10 +- .../system/SystemFunctionsTest.java | 12 +- opensearch/build.gradle | 1 + .../sql/opensearch/OpenSearchTestBase.java | 23 ++ .../OpenSearchDataTypeRecognitionTest.java | 6 +- .../logical/OpenSearchLogicOptimizerTest.java | 6 +- .../storage/OpenSearchIndexTest.java | 7 +- .../AggregationQueryBuilderTest.java | 5 +- .../ExpressionAggregationScriptTest.java | 5 +- .../dsl/MetricAggregationBuilderTest.java | 4 +- .../filter/ExpressionFilterScriptTest.java | 5 +- .../script/filter/FilterQueryBuilderTest.java | 229 +++++++++--------- .../script/filter/lucene/LuceneQueryTest.java | 4 +- .../lucene/MatchBoolPrefixQueryTest.java | 41 ++-- .../lucene/MatchPhrasePrefixQueryTest.java | 46 ++-- .../filter/lucene/MatchPhraseQueryTest.java | 40 +-- .../script/filter/lucene/MatchQueryTest.java | 98 ++++---- .../script/filter/lucene/MultiMatchTest.java | 100 ++++---- .../script/filter/lucene/QueryStringTest.java | 84 ++++--- .../filter/lucene/SimpleQueryStringTest.java | 140 +++++------ .../lucene/relevance/MultiFieldQueryTest.java | 11 +- .../relevance/SingleFieldQueryTest.java | 5 +- .../script/sort/SortQueryBuilderTest.java | 8 +- .../DefaultExpressionSerializerTest.java | 6 +- ppl/src/main/antlr/OpenSearchPPLParser.g4 | 67 +++-- .../sql/ppl/parser/AstExpressionBuilder.java | 31 +-- .../ppl/parser/AstNowLikeFunctionTest.java | 101 ++++---- prometheus/build.gradle | 1 + .../QueryRangeTableFunctionResolver.java | 2 +- .../QueryRangeTableFunctionResolverTest.java | 16 +- .../logical/PrometheusLogicOptimizerTest.java | 10 +- .../storage/PrometheusMetricTableTest.java | 12 +- sql/src/main/antlr/OpenSearchSQLParser.g4 | 65 +++-- .../sql/sql/parser/AstExpressionBuilder.java | 26 +- .../sql/sql/parser/AstBuilderTest.java | 66 +---- .../sql/sql/parser/AstBuilderTestBase.java | 22 ++ .../sql/parser/AstNowLikeFunctionTest.java | 98 ++++++++ 76 files changed, 1903 insertions(+), 1045 deletions(-) delete mode 100644 core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java create mode 100644 core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/config/ExpressionConfigTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/OpenSearchTestBase.java create mode 100644 sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java create mode 100644 sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java diff --git a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java index f3fd623371..4ad1997e04 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java +++ b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java @@ -7,12 +7,9 @@ package org.opensearch.sql.analysis; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Objects; import lombok.Getter; -import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; /** @@ -26,14 +23,6 @@ public class AnalysisContext { @Getter private final List namedParseExpressions; - /** - * Storage for values of functions which return a constant value. - * We are storing the values there to use it in sequential calls to those functions. - * For example, `now` function should the same value during processing a query. - */ - @Getter - private final Map constantFunctionValues; - public AnalysisContext() { this(new TypeEnvironment(null)); } @@ -45,7 +34,6 @@ public AnalysisContext() { public AnalysisContext(TypeEnvironment environment) { this.environment = environment; this.namedParseExpressions = new ArrayList<>(); - this.constantFunctionValues = new HashMap<>(); } /** diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index 061c4b505f..9e90256f4b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -22,11 +22,9 @@ import org.opensearch.sql.ast.expression.AggregateFunction; import org.opensearch.sql.ast.expression.AllFields; import org.opensearch.sql.ast.expression.And; -import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -176,19 +174,6 @@ public Expression visitRelevanceFieldList(RelevanceFieldList node, AnalysisConte ImmutableMap.copyOf(node.getFieldList()))); } - @Override - public Expression visitConstantFunction(ConstantFunction node, AnalysisContext context) { - var valueName = node.getFuncName(); - if (context.getConstantFunctionValues().containsKey(valueName)) { - return context.getConstantFunctionValues().get(valueName); - } - - var value = visitFunction(node, context); - value = DSL.literal(value.valueOf(null)); - context.getConstantFunctionValues().put(valueName, value); - return value; - } - @Override public Expression visitFunction(Function node, AnalysisContext context) { FunctionName functionName = FunctionName.of(node.getFuncName()); diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 60e7d6f06e..1a9467313f 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -15,7 +15,6 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -123,10 +122,6 @@ public T visitRelevanceFieldList(RelevanceFieldList node, C context) { return visitChildren(node, context); } - public T visitConstantFunction(ConstantFunction node, C context) { - return visitChildren(node, context); - } - public T visitUnresolvedAttribute(UnresolvedAttribute node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 24ada1fd92..2959cae4a1 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -19,7 +19,6 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; @@ -246,10 +245,6 @@ public static Function function(String funcName, UnresolvedExpression... funcArg return new Function(funcName, Arrays.asList(funcArgs)); } - public static Function constantFunction(String funcName, UnresolvedExpression... funcArgs) { - return new ConstantFunction(funcName, Arrays.asList(funcArgs)); - } - /** * CASE * WHEN search_condition THEN result_expr diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java deleted file mode 100644 index f14e65eeb2..0000000000 --- a/core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.ast.expression; - -import java.util.List; -import lombok.EqualsAndHashCode; -import org.opensearch.sql.ast.AbstractNodeVisitor; - -/** - * Expression node that holds a function which should be replaced by its constant[1] value. - * [1] Constant at execution time. - */ -@EqualsAndHashCode(callSuper = false) -public class ConstantFunction extends Function { - - public ConstantFunction(String funcName, List funcArgs) { - super(funcName, funcArgs); - } - - @Override - public R accept(AbstractNodeVisitor nodeVisitor, C context) { - return nodeVisitor.visitConstantFunction(this, context); - } -} diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index e65dbd6fcb..512784174f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -124,7 +124,7 @@ public static NamedArgumentExpression namedArgument(String argName, Expression v return new NamedArgumentExpression(argName, value); } - public NamedArgumentExpression namedArgument(String name, String value) { + public static NamedArgumentExpression namedArgument(String name, String value) { return namedArgument(name, literal(value)); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index 9fbf1557aa..c30ca13351 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -68,7 +68,7 @@ private static DefaultFunctionResolver avg() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new AvgAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new AvgAggregator(arguments, DOUBLE)) .build() ); } @@ -78,7 +78,7 @@ private static DefaultFunctionResolver count() { DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, ExprCoreType.coreTypes().stream().collect(Collectors.toMap( type -> new FunctionSignature(functionName, Collections.singletonList(type)), - type -> arguments -> new CountAggregator(arguments, INTEGER)))); + type -> (functionProperties, arguments) -> new CountAggregator(arguments, INTEGER)))); return functionResolver; } @@ -88,13 +88,13 @@ private static DefaultFunctionResolver sum() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new SumAggregator(arguments, INTEGER)) + (functionProperties, arguments) -> new SumAggregator(arguments, INTEGER)) .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new SumAggregator(arguments, LONG)) + (functionProperties, arguments) -> new SumAggregator(arguments, LONG)) .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new SumAggregator(arguments, FLOAT)) + (functionProperties, arguments) -> new SumAggregator(arguments, FLOAT)) .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new SumAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new SumAggregator(arguments, DOUBLE)) .build() ); } @@ -105,23 +105,23 @@ private static DefaultFunctionResolver min() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new MinAggregator(arguments, INTEGER)) + (functionProperties, arguments) -> new MinAggregator(arguments, INTEGER)) .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new MinAggregator(arguments, LONG)) + (functionProperties, arguments) -> new MinAggregator(arguments, LONG)) .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new MinAggregator(arguments, FLOAT)) + (functionProperties, arguments) -> new MinAggregator(arguments, FLOAT)) .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new MinAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new MinAggregator(arguments, DOUBLE)) .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - arguments -> new MinAggregator(arguments, STRING)) + (functionProperties, arguments) -> new MinAggregator(arguments, STRING)) .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - arguments -> new MinAggregator(arguments, DATE)) + (functionProperties, arguments) -> new MinAggregator(arguments, DATE)) .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - arguments -> new MinAggregator(arguments, DATETIME)) + (functionProperties, arguments) -> new MinAggregator(arguments, DATETIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - arguments -> new MinAggregator(arguments, TIME)) + (functionProperties, arguments) -> new MinAggregator(arguments, TIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - arguments -> new MinAggregator(arguments, TIMESTAMP)) + (functionProperties, arguments) -> new MinAggregator(arguments, TIMESTAMP)) .build()); } @@ -131,23 +131,23 @@ private static DefaultFunctionResolver max() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new MaxAggregator(arguments, INTEGER)) + (functionProperties, arguments) -> new MaxAggregator(arguments, INTEGER)) .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new MaxAggregator(arguments, LONG)) + (functionProperties, arguments) -> new MaxAggregator(arguments, LONG)) .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new MaxAggregator(arguments, FLOAT)) + (functionProperties, arguments) -> new MaxAggregator(arguments, FLOAT)) .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new MaxAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new MaxAggregator(arguments, DOUBLE)) .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - arguments -> new MaxAggregator(arguments, STRING)) + (functionProperties, arguments) -> new MaxAggregator(arguments, STRING)) .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - arguments -> new MaxAggregator(arguments, DATE)) + (functionProperties, arguments) -> new MaxAggregator(arguments, DATE)) .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - arguments -> new MaxAggregator(arguments, DATETIME)) + (functionProperties, arguments) -> new MaxAggregator(arguments, DATETIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - arguments -> new MaxAggregator(arguments, TIME)) + (functionProperties, arguments) -> new MaxAggregator(arguments, TIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - arguments -> new MaxAggregator(arguments, TIMESTAMP)) + (functionProperties, arguments) -> new MaxAggregator(arguments, TIMESTAMP)) .build() ); } @@ -158,7 +158,7 @@ private static DefaultFunctionResolver varSamp() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> varianceSample(arguments, DOUBLE)) + (functionProperties, arguments) -> varianceSample(arguments, DOUBLE)) .build() ); } @@ -169,7 +169,7 @@ private static DefaultFunctionResolver varPop() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> variancePopulation(arguments, DOUBLE)) + (functionProperties, arguments) -> variancePopulation(arguments, DOUBLE)) .build() ); } @@ -180,7 +180,7 @@ private static DefaultFunctionResolver stddevSamp() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> stddevSample(arguments, DOUBLE)) + (functionProperties, arguments) -> stddevSample(arguments, DOUBLE)) .build() ); } @@ -191,7 +191,7 @@ private static DefaultFunctionResolver stddevPop() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> stddevPopulation(arguments, DOUBLE)) + (functionProperties, arguments) -> stddevPopulation(arguments, DOUBLE)) .build() ); } @@ -201,7 +201,7 @@ private static DefaultFunctionResolver take() { DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, ImmutableList.of(STRING, INTEGER)), - arguments -> new TakeAggregator(arguments, ARRAY)) + (functionProperties, arguments) -> new TakeAggregator(arguments, ARRAY)) .build()); return functionResolver; } diff --git a/core/src/main/java/org/opensearch/sql/expression/config/ExpressionConfig.java b/core/src/main/java/org/opensearch/sql/expression/config/ExpressionConfig.java index c68086ab4d..ccae09d050 100644 --- a/core/src/main/java/org/opensearch/sql/expression/config/ExpressionConfig.java +++ b/core/src/main/java/org/opensearch/sql/expression/config/ExpressionConfig.java @@ -6,12 +6,16 @@ package org.opensearch.sql.expression.config; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; import java.util.HashMap; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.aggregation.AggregatorFunction; import org.opensearch.sql.expression.datetime.DateTimeFunction; import org.opensearch.sql.expression.datetime.IntervalClause; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.expression.function.OpenSearchFunctions; import org.opensearch.sql.expression.operator.arthmetic.ArithmeticFunction; import org.opensearch.sql.expression.operator.arthmetic.MathematicalFunction; @@ -33,9 +37,10 @@ public class ExpressionConfig { * BuiltinFunctionRepository constructor. */ @Bean - public BuiltinFunctionRepository functionRepository() { + public BuiltinFunctionRepository functionRepository(FunctionProperties functionContext) { + BuiltinFunctionRepository builtinFunctionRepository = - new BuiltinFunctionRepository(new HashMap<>()); + new BuiltinFunctionRepository(new HashMap<>(), functionContext); ArithmeticFunction.register(builtinFunctionRepository); BinaryPredicateOperator.register(builtinFunctionRepository); MathematicalFunction.register(builtinFunctionRepository); @@ -51,8 +56,14 @@ public BuiltinFunctionRepository functionRepository() { return builtinFunctionRepository; } + @Bean + public FunctionProperties functionExecutionContext() { + return new FunctionProperties(Instant.now(), ZoneId.systemDefault()); + } + @Bean public DSL dsl(BuiltinFunctionRepository repository) { return new DSL(repository); } + } diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java index 42274c0756..0d194e7197 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java @@ -18,6 +18,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.expression.function.FunctionDSL.define; import static org.opensearch.sql.expression.function.FunctionDSL.impl; +import static org.opensearch.sql.expression.function.FunctionDSL.implWithProperties; import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_FORMATTER_LONG_YEAR; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_FORMATTER_SHORT_YEAR; @@ -28,6 +29,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.text.DecimalFormat; +import java.time.Clock; import java.time.DateTimeException; import java.time.Instant; import java.time.LocalDate; @@ -42,7 +44,6 @@ import java.util.Locale; import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDatetimeValue; @@ -60,7 +61,9 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; +import org.opensearch.sql.expression.function.FunctionDSL; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.utils.DateTimeUtils; /** @@ -128,6 +131,78 @@ public void register(BuiltinFunctionRepository repository) { repository.register(year()); } + /** + * NOW() returns a constant time that indicates the time at which the statement began to execute. + * `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and + * `now(y) return different values. + */ + private FunctionResolver now(FunctionName functionName) { + return define(functionName, + implWithProperties( + functionProperties -> new ExprDatetimeValue( + formatNow(functionProperties.getQueryStartClock())), DATETIME) + ); + } + + private FunctionResolver now() { + return now(BuiltinFunctionName.NOW.getName()); + } + + private FunctionResolver current_timestamp() { + return now(BuiltinFunctionName.CURRENT_TIMESTAMP.getName()); + } + + private FunctionResolver localtimestamp() { + return now(BuiltinFunctionName.LOCALTIMESTAMP.getName()); + } + + private FunctionResolver localtime() { + return now(BuiltinFunctionName.LOCALTIME.getName()); + } + + /** + * SYSDATE() returns the time at which it executes. + */ + private FunctionResolver sysdate() { + return define(BuiltinFunctionName.SYSDATE.getName(), + implWithProperties(functionProperties + -> new ExprDatetimeValue(formatNow(Clock.systemDefaultZone())), DATETIME), + FunctionDSL.implWithProperties((functionProperties, v) -> new ExprDatetimeValue( + formatNow(Clock.systemDefaultZone(), v.integerValue())), DATETIME, INTEGER) + ); + } + + /** + * Synonym for @see `now`. + */ + private FunctionResolver curtime(FunctionName functionName) { + return define(functionName, + implWithProperties(functionProperties -> new ExprTimeValue( + formatNow(functionProperties.getQueryStartClock()).toLocalTime()), TIME)); + } + + private FunctionResolver curtime() { + return curtime(BuiltinFunctionName.CURTIME.getName()); + } + + private FunctionResolver current_time() { + return curtime(BuiltinFunctionName.CURRENT_TIME.getName()); + } + + private FunctionResolver curdate(FunctionName functionName) { + return define(functionName, + implWithProperties(functionProperties -> new ExprDateValue( + formatNow(functionProperties.getQueryStartClock()).toLocalDate()), DATE)); + } + + private FunctionResolver curdate() { + return curdate(BuiltinFunctionName.CURDATE.getName()); + } + + private FunctionResolver current_date() { + return curdate(BuiltinFunctionName.CURRENT_DATE.getName()); + } + /** * Specify a start date and add a temporal amount to the date. * The return type depends on the date type and the interval unit. Detailed supported signatures: @@ -170,41 +245,6 @@ private DefaultFunctionResolver convert_tz() { ); } - private DefaultFunctionResolver curdate(FunctionName functionName) { - return define(functionName, - impl(() -> new ExprDateValue(formatNow(null).toLocalDate()), DATE) - ); - } - - private DefaultFunctionResolver curdate() { - return curdate(BuiltinFunctionName.CURDATE.getName()); - } - - /** - * Synonym for @see `now`. - */ - private DefaultFunctionResolver curtime(FunctionName functionName) { - return define(functionName, - impl(() -> new ExprTimeValue(formatNow(null).toLocalTime()), TIME) - ); - } - - private DefaultFunctionResolver curtime() { - return curtime(BuiltinFunctionName.CURTIME.getName()); - } - - private DefaultFunctionResolver current_date() { - return curdate(BuiltinFunctionName.CURRENT_DATE.getName()); - } - - private DefaultFunctionResolver current_time() { - return curtime(BuiltinFunctionName.CURRENT_TIME.getName()); - } - - private DefaultFunctionResolver current_timestamp() { - return now(BuiltinFunctionName.CURRENT_TIMESTAMP.getName()); - } - /** * Extracts the date part of a date and time value. * Also to construct a date type. The supported signatures: @@ -224,7 +264,7 @@ private DefaultFunctionResolver date() { * (STRING, STRING) -> DATETIME * (STRING) -> DATETIME */ - private DefaultFunctionResolver datetime() { + private FunctionResolver datetime() { return define(BuiltinFunctionName.DATETIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprDateTime), DATETIME, STRING, STRING), @@ -336,7 +376,7 @@ private DefaultFunctionResolver from_days() { impl(nullMissingHandling(DateTimeFunction::exprFromDays), DATE, LONG)); } - private DefaultFunctionResolver from_unixtime() { + private FunctionResolver from_unixtime() { return define(BuiltinFunctionName.FROM_UNIXTIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprFromUnixTime), DATETIME, DOUBLE), impl(nullMissingHandling(DateTimeFunction::exprFromUnixTimeFormat), @@ -355,35 +395,12 @@ private DefaultFunctionResolver hour() { ); } - private DefaultFunctionResolver localtime() { - return now(BuiltinFunctionName.LOCALTIME.getName()); - } - - private DefaultFunctionResolver localtimestamp() { - return now(BuiltinFunctionName.LOCALTIMESTAMP.getName()); - } - - /** - * NOW() returns a constant time that indicates the time at which the statement began to execute. - * `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and - * `now(y) return different values. - */ - private DefaultFunctionResolver now(FunctionName functionName) { - return define(functionName, - impl(() -> new ExprDatetimeValue(formatNow(null)), DATETIME) - ); - } - - private DefaultFunctionResolver now() { - return now(BuiltinFunctionName.NOW.getName()); - } - - private DefaultFunctionResolver makedate() { + private FunctionResolver makedate() { return define(BuiltinFunctionName.MAKEDATE.getName(), impl(nullMissingHandling(DateTimeFunction::exprMakeDate), DATE, DOUBLE, DOUBLE)); } - private DefaultFunctionResolver maketime() { + private FunctionResolver maketime() { return define(BuiltinFunctionName.MAKETIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprMakeTime), TIME, DOUBLE, DOUBLE, DOUBLE)); } @@ -535,9 +552,10 @@ private DefaultFunctionResolver to_days() { impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, DATETIME)); } - private DefaultFunctionResolver unix_timestamp() { + private FunctionResolver unix_timestamp() { return define(BuiltinFunctionName.UNIX_TIMESTAMP.getName(), - impl(DateTimeFunction::unixTimeStamp, LONG), + implWithProperties(functionProperties + -> DateTimeFunction.unixTimeStamp(functionProperties.getQueryStartClock()), LONG), impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATE), impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATETIME), impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, TIMESTAMP), @@ -1004,16 +1022,6 @@ private ExprValue exprSubDateInterval(ExprValue date, ExprValue expr) { : exprValue); } - /** - * SYSDATE() returns the time at which it executes. - */ - private DefaultFunctionResolver sysdate() { - return define(BuiltinFunctionName.SYSDATE.getName(), - impl(() -> new ExprDatetimeValue(formatNow(null)), DATETIME), - impl((v) -> new ExprDatetimeValue(formatNow(v.integerValue())), DATETIME, INTEGER) - ); - } - /** * Time implementation for ExprValue. * @@ -1073,8 +1081,8 @@ private ExprValue exprWeek(ExprValue date, ExprValue mode) { CalendarLookup.getWeekNumber(mode.integerValue(), date.dateValue())); } - private ExprValue unixTimeStamp() { - return new ExprLongValue(Instant.now().getEpochSecond()); + private ExprValue unixTimeStamp(Clock clock) { + return new ExprLongValue(Instant.now(clock).getEpochSecond()); } private ExprValue unixTimeStampOf(ExprValue value) { @@ -1167,17 +1175,18 @@ private ExprValue exprYear(ExprValue date) { return new ExprIntegerValue(date.dateValue().getYear()); } + private LocalDateTime formatNow(Clock clock) { + return formatNow(clock, 0); + } + /** * Prepare LocalDateTime value. Truncate fractional second part according to the argument. * @param fsp argument is given to specify a fractional seconds precision from 0 to 6, * the return value includes a fractional seconds part of that many digits. * @return LocalDateTime object. */ - private LocalDateTime formatNow(@Nullable Integer fsp) { - var res = LocalDateTime.now(); - if (fsp == null) { - fsp = 0; - } + private LocalDateTime formatNow(Clock clock, Integer fsp) { + var res = LocalDateTime.now(clock); var defaultPrecision = 9; // There are 10^9 nanoseconds in one second if (fsp < 0 || fsp > 6) { // Check that the argument is in the allowed range [0, 6] throw new IllegalArgumentException( diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 33f652d534..08692481cd 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -10,16 +10,11 @@ import com.google.common.collect.ImmutableList; import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; import java.util.HashMap; -import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Set; -import java.util.TreeSet; import java.util.stream.Collectors; +import lombok.Getter; import lombok.RequiredArgsConstructor; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.common.utils.StringUtils; @@ -37,11 +32,13 @@ */ @RequiredArgsConstructor public class BuiltinFunctionRepository { - public static final String DEFAULT_NAMESPACE = "default"; private final Map> namespaceFunctionResolverMap; + @Getter + private final FunctionProperties functionProperties; + /** * Register {@link DefaultFunctionResolver} to the Builtin Function Repository @@ -69,6 +66,11 @@ public void register(String namespace, FunctionResolver resolver) { } + public FunctionImplementation compile(BuiltinFunctionName functionName, + List expressions) { + return compile(functionName.getName(), expressions); + } + /** * Compile FunctionExpression under default namespace. * @@ -90,8 +92,8 @@ public FunctionImplementation compile(String namespace, FunctionName functionNam } FunctionBuilder resolvedFunctionBuilder = resolve(namespaceList, new FunctionSignature(functionName, expressions - .stream().map(expression -> expression.type()).collect(Collectors.toList()))); - return resolvedFunctionBuilder.apply(expressions); + .stream().map(Expression::type).collect(Collectors.toList()))); + return resolvedFunctionBuilder.apply(functionProperties, expressions); } /** @@ -148,7 +150,7 @@ private FunctionBuilder getFunctionBuilder(FunctionSignature functionSignature, private FunctionBuilder castArguments(List sourceTypes, List targetTypes, FunctionBuilder funcBuilder) { - return arguments -> { + return (functionProperties, arguments) -> { List argsCasted = new ArrayList<>(); for (int i = 0; i < arguments.size(); i++) { Expression arg = arguments.get(i); @@ -161,7 +163,7 @@ private FunctionBuilder castArguments(List sourceTypes, argsCasted.add(arg); } } - return funcBuilder.apply(argsCasted); + return funcBuilder.apply(functionProperties, argsCasted); }; } @@ -182,5 +184,4 @@ private Expression cast(Expression arg, ExprType targetType) { } return (Expression) compile(castFunctionName, ImmutableList.of(arg)); } - } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java index aa3077051b..02766c2076 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java @@ -21,5 +21,5 @@ public interface FunctionBuilder { * @param arguments {@link Expression} list * @return {@link FunctionImplementation} */ - FunctionImplementation apply(List arguments); + FunctionImplementation apply(FunctionProperties functionProperties, List arguments); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index 1fad333ead..472025a44a 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -19,6 +19,7 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.DefaultFunctionResolver.DefaultFunctionResolverBuilder; /** * Function Define Utility. @@ -35,7 +36,7 @@ public class FunctionDSL { public static DefaultFunctionResolver define(FunctionName functionName, SerializableFunction>... functions) { - return define(functionName, Arrays.asList(functions)); + return define(functionName, List.of(functions)); } /** @@ -48,8 +49,7 @@ public static DefaultFunctionResolver define(FunctionName functionName, public static DefaultFunctionResolver define(FunctionName functionName, List< SerializableFunction>> functions) { - DefaultFunctionResolver.DefaultFunctionResolverBuilder builder - = DefaultFunctionResolver.builder(); + DefaultFunctionResolverBuilder builder = DefaultFunctionResolver.builder(); builder.functionName(functionName); for (Function> func : functions) { Pair functionBuilder = func.apply(functionName); @@ -58,51 +58,54 @@ public static DefaultFunctionResolver define(FunctionName functionName, List< return builder.build(); } + /** - * No Arg Function Implementation. + * Implementation of no args function that uses FunctionProperties. * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @return Unary Function Implementation. + * @param function {@link ExprValue} based no args function. + * @param returnType function return type. + * @return no args function implementation. */ - public static SerializableFunction> impl( - SerializableNoArgFunction function, - ExprType returnType) { - + public static SerializableFunction> + implWithProperties(SerializableFunction function, + ExprType returnType) { return functionName -> { FunctionSignature functionSignature = new FunctionSignature(functionName, Collections.emptyList()); FunctionBuilder functionBuilder = - arguments -> new FunctionExpression(functionName, Collections.emptyList()) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return function.get(); - } + (functionProperties, arguments) -> + new FunctionExpression(functionName, Collections.emptyList()) { + @Override + public ExprValue valueOf(Environment valueEnv) { + return function.apply(functionProperties); + } - @Override - public ExprType type() { - return returnType; - } + @Override + public ExprType type() { + return returnType; + } - @Override - public String toString() { - return String.format("%s()", functionName); - } - }; + @Override + public String toString() { + return String.format("%s()", functionName); + } + }; return Pair.of(functionSignature, functionBuilder); }; } /** - * Unary Function Implementation. + * Implementation of a function that takes one argument, returns a value, and + * requires FunctionProperties to complete. * * @param function {@link ExprValue} based unary function. * @param returnType return type. - * @param argsType argument type. + * @param argsType argument type. * @return Unary Function Implementation. */ - public static SerializableFunction> impl( - SerializableFunction function, + public static SerializableFunction> + implWithProperties( + SerializableBiFunction function, ExprType returnType, ExprType argsType) { @@ -110,11 +113,11 @@ public static SerializableFunction new FunctionExpression(functionName, arguments) { + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { @Override public ExprValue valueOf(Environment valueEnv) { ExprValue value = arguments.get(0).valueOf(valueEnv); - return function.apply(value); + return function.apply(functionProperties, value); } @Override @@ -134,6 +137,35 @@ public String toString() { }; } + /** + * No Arg Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableNoArgFunction function, + ExprType returnType) { + return implWithProperties(fp -> function.get(), returnType); + } + + /** + * Unary Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param argsType argument type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableFunction function, + ExprType returnType, + ExprType argsType) { + + return implWithProperties((fp, arg) -> function.apply(arg), returnType, argsType); + } + /** * Binary Function Implementation. * @@ -153,7 +185,7 @@ public static SerializableFunction new FunctionExpression(functionName, arguments) { + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { @Override public ExprValue valueOf(Environment valueEnv) { ExprValue arg1 = arguments.get(0).valueOf(valueEnv); @@ -196,7 +228,7 @@ public static SerializableFunction new FunctionExpression(functionName, arguments) { + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { @Override public ExprValue valueOf(Environment valueEnv) { ExprValue arg1 = arguments.get(0).valueOf(valueEnv); @@ -267,4 +299,22 @@ public SerializableTriFunction nullM } }; } + + /** + * Wrapper the unary ExprValue function that is aware of FunctionProperties, + * with default NULL and MISSING handling. + */ + public static SerializableBiFunction + nullMissingHandlingWithProperties( + SerializableBiFunction implementation) { + return (functionProperties, v1) -> { + if (v1.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (v1.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return implementation.apply(functionProperties, v1); + } + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java new file mode 100644 index 0000000000..279019973e --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.io.Serializable; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; + +/** + * Class to capture values that may be necessary to implement some functions. + * An example would be query execution start time to implement now(). + */ +@RequiredArgsConstructor +@EqualsAndHashCode +public class FunctionProperties implements Serializable { + + private final Instant nowInstant; + private final ZoneId currentZoneId; + + + /** + * Method to access current system clock. + * @return a ticking clock that tells the time. + */ + public Clock getSystemClock() { + return Clock.system(currentZoneId); + } + + /** + * Method to get time when query began execution. + * Clock class combines an instant Supplier and a time zone. + * @return a fixed clock that returns the time execution started at. + * + */ + public Clock getQueryStartClock() { + return Clock.fixed(nowInstant, currentZoneId); + } +} + diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java index e781db8c84..7066622e1b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java @@ -47,8 +47,8 @@ public Pair resolve(FunctionSignature unreso } } - FunctionBuilder buildFunction = - args -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); + FunctionBuilder buildFunction = (functionProperties, args) + -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); return Pair.of(unresolvedSignature, buildFunction); } diff --git a/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java b/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java index 5e955c2e62..6d8dd5093e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java @@ -38,7 +38,8 @@ private static FunctionResolver typeof() { public Pair resolve( FunctionSignature unresolvedSignature) { return Pair.of(unresolvedSignature, - arguments -> new FunctionExpression(BuiltinFunctionName.TYPEOF.getName(), arguments) { + (functionProperties, arguments) -> + new FunctionExpression(BuiltinFunctionName.TYPEOF.getName(), arguments) { @Override public ExprValue valueOf(Environment valueEnv) { return new ExprStringValue(getArguments().get(0).type().toString()); diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java index a3baf08ff3..9a9e0c4c86 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java @@ -54,9 +54,8 @@ private DefaultFunctionResolver denseRank() { private DefaultFunctionResolver rankingFunction(FunctionName functionName, Supplier constructor) { FunctionSignature functionSignature = new FunctionSignature(functionName, emptyList()); - FunctionBuilder functionBuilder = arguments -> constructor.get(); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> constructor.get(); return new DefaultFunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); } - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index 3523fed98e..dc97d602fa 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -152,7 +152,8 @@ public Pair resolve( FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING, LONG, LONG, LONG)); return Pair.of(functionSignature, - args -> new TestTableFunctionImplementation(functionName, args, table)); + (functionProperties, args) -> new TestTableFunctionImplementation(functionName, args, + table)); } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 98176f0002..1ef884f320 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -565,27 +565,12 @@ public void match_phrase_prefix_all_params() { ); } - @Test - public void constant_function_is_calculated_on_analyze() { - // Actually, we can call any function as ConstantFunction to be calculated on analyze stage - assertTrue(analyze(AstDSL.constantFunction("now")) instanceof LiteralExpression); - assertTrue(analyze(AstDSL.constantFunction("localtime")) instanceof LiteralExpression); - } - @Test public void function_isnt_calculated_on_analyze() { assertTrue(analyze(function("now")) instanceof FunctionExpression); assertTrue(analyze(AstDSL.function("localtime")) instanceof FunctionExpression); } - @Test - public void constant_function_returns_constant_cached_value() { - var values = List.of(analyze(AstDSL.constantFunction("now")), - analyze(AstDSL.constantFunction("now")), analyze(AstDSL.constantFunction("now"))); - assertTrue(values.stream().allMatch(v -> - v.valueOf(null) == analyze(AstDSL.constantFunction("now")).valueOf(null))); - } - @Test public void function_returns_non_constant_value() { // Even a function returns the same values - they are calculated on each call diff --git a/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java index b1b22bedb1..e089ae376f 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/ExpressionNodeVisitorTest.java @@ -20,17 +20,25 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.opensearch.sql.analysis.AnalyzerTestBase; import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.conditional.cases.CaseClause; import org.opensearch.sql.expression.conditional.cases.WhenClause; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.parse.ParseExpression; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ExpressionConfig.class, AnalyzerTestBase.class}) class ExpressionNodeVisitorTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + @Autowired + private DSL dsl; @Test void should_return_null_by_default() { diff --git a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java index fea985042a..dd28ea8975 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java @@ -36,6 +36,7 @@ import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -50,6 +51,9 @@ public class ExpressionTestBase { @Autowired protected DSL dsl; + @Autowired + protected FunctionProperties functionProperties; + @Autowired protected Environment typeEnv; diff --git a/core/src/test/java/org/opensearch/sql/expression/config/ExpressionConfigTest.java b/core/src/test/java/org/opensearch/sql/expression/config/ExpressionConfigTest.java new file mode 100644 index 0000000000..d36c90599e --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/config/ExpressionConfigTest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.config; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import org.junit.jupiter.api.Test; +import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.springframework.context.annotation.AnnotationConfigApplicationContext; + +class ExpressionConfigTest { + private static AnnotationConfigApplicationContext createContext() { + var context = new AnnotationConfigApplicationContext(); + context.register(ExpressionConfig.class); + context.refresh(); + return context; + } + + @Test + void testContextIsFromBean() { + AnnotationConfigApplicationContext context = createContext(); + BuiltinFunctionRepository repository = context.getBean(BuiltinFunctionRepository.class); + assertEquals(repository.getFunctionProperties(), + context.getBean(FunctionProperties.class)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java b/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java index 91d1fc4b0f..7f3c26276b 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java @@ -5,21 +5,18 @@ package org.opensearch.sql.expression.datetime; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.expression.function.BuiltinFunctionRepository.DEFAULT_NAMESPACE; - import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; -import java.util.Collections; import java.util.List; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDatetimeValue; +import org.opensearch.sql.data.model.ExprMissingValue; +import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.DSL; @@ -27,9 +24,8 @@ import org.opensearch.sql.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; -import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionSignature; import org.springframework.beans.factory.annotation.Autowired; @ExtendWith(MockitoExtension.class) @@ -38,11 +34,9 @@ public class DateTimeTestBase extends ExpressionTestBase { @Mock protected Environment env; - @Mock - protected Expression nullRef; + protected Expression nullRef = DSL.literal(ExprNullValue.of()); - @Mock - protected Expression missingRef; + protected Expression missingRef = DSL.literal(ExprMissingValue.of()); @Autowired protected BuiltinFunctionRepository functionRepository; @@ -52,17 +46,13 @@ protected ExprValue eval(Expression expression) { } protected FunctionExpression fromUnixTime(Expression value) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("from_unixtime"), - List.of(value.type()))); - return (FunctionExpression)func.apply(List.of(value)); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.FROM_UNIXTIME, List.of(value)); } protected FunctionExpression fromUnixTime(Expression value, Expression format) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("from_unixtime"), - List.of(value.type(), format.type()))); - return (FunctionExpression)func.apply(List.of(value, format)); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.FROM_UNIXTIME, List.of(value, format)); } protected LocalDateTime fromUnixTime(Long value) { @@ -74,29 +64,18 @@ protected LocalDateTime fromUnixTime(Double value) { } protected String fromUnixTime(Long value, String format) { - return fromUnixTime(DSL.literal(value), DSL.literal(format)).valueOf(null).stringValue(); + return fromUnixTime(DSL.literal(value), DSL.literal(format)) + .valueOf(null).stringValue(); } protected String fromUnixTime(Double value, String format) { - return fromUnixTime(DSL.literal(value), DSL.literal(format)).valueOf(null).stringValue(); - } - - protected FunctionExpression makedate(Expression year, Expression dayOfYear) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("makedate"), - List.of(DOUBLE, DOUBLE))); - return (FunctionExpression)func.apply(List.of(year, dayOfYear)); - } - - protected LocalDate makedate(Double year, Double dayOfYear) { - return makedate(DSL.literal(year), DSL.literal(dayOfYear)).valueOf(null).dateValue(); + return fromUnixTime(DSL.literal(value), DSL.literal(format)) + .valueOf(null).stringValue(); } protected FunctionExpression maketime(Expression hour, Expression minute, Expression second) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("maketime"), - List.of(DOUBLE, DOUBLE, DOUBLE))); - return (FunctionExpression)func.apply(List.of(hour, minute, second)); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.MAKETIME, List.of(hour, minute, second)); } protected LocalTime maketime(Double hour, Double minute, Double second) { @@ -104,32 +83,38 @@ protected LocalTime maketime(Double hour, Double minute, Double second) { .valueOf(null).timeValue(); } + protected FunctionExpression makedate(Expression year, Expression dayOfYear) { + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.MAKEDATE, List.of(year, dayOfYear)); + } + + protected LocalDate makedate(double year, double dayOfYear) { + return makedate(DSL.literal(year), DSL.literal(dayOfYear)).valueOf(null).dateValue(); + } + protected FunctionExpression period_add(Expression period, Expression months) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("period_add"), - List.of(INTEGER, INTEGER))); - return (FunctionExpression)func.apply(List.of(period, months)); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.PERIOD_ADD, List.of(period, months)); } protected Integer period_add(Integer period, Integer months) { - return period_add(DSL.literal(period), DSL.literal(months)).valueOf(null).integerValue(); + return period_add(DSL.literal(period), DSL.literal(months)) + .valueOf(null).integerValue(); } protected FunctionExpression period_diff(Expression first, Expression second) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("period_diff"), - List.of(INTEGER, INTEGER))); - return (FunctionExpression)func.apply(List.of(first, second)); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.PERIOD_DIFF, List.of(first, second)); } protected Integer period_diff(Integer first, Integer second) { - return period_diff(DSL.literal(first), DSL.literal(second)).valueOf(null).integerValue(); + return period_diff(DSL.literal(first), DSL.literal(second)) + .valueOf(null).integerValue(); } protected FunctionExpression unixTimeStampExpr() { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("unix_timestamp"), List.of())); - return (FunctionExpression)func.apply(List.of()); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.UNIX_TIMESTAMP, List.of()); } protected Long unixTimeStamp() { @@ -137,10 +122,8 @@ protected Long unixTimeStamp() { } protected FunctionExpression unixTimeStampOf(Expression value) { - var func = functionRepository.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("unix_timestamp"), - List.of(value.type()))); - return (FunctionExpression)func.apply(List.of(value)); + return (FunctionExpression) + functionRepository.compile(BuiltinFunctionName.UNIX_TIMESTAMP, List.of(value)); } protected Double unixTimeStampOf(Double value) { @@ -148,14 +131,17 @@ protected Double unixTimeStampOf(Double value) { } protected Double unixTimeStampOf(LocalDate value) { - return unixTimeStampOf(DSL.literal(new ExprDateValue(value))).valueOf(null).doubleValue(); + return unixTimeStampOf(DSL.literal(new ExprDateValue(value))) + .valueOf(null).doubleValue(); } protected Double unixTimeStampOf(LocalDateTime value) { - return unixTimeStampOf(DSL.literal(new ExprDatetimeValue(value))).valueOf(null).doubleValue(); + return unixTimeStampOf(DSL.literal(new ExprDatetimeValue(value))) + .valueOf(null).doubleValue(); } protected Double unixTimeStampOf(Instant value) { - return unixTimeStampOf(DSL.literal(new ExprTimestampValue(value))).valueOf(null).doubleValue(); + return unixTimeStampOf(DSL.literal(new ExprTimestampValue(value))) + .valueOf(null).doubleValue(); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java index 981468f31e..7772c1b77f 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java @@ -60,8 +60,6 @@ public void checkNullValues() { @Test public void checkMissingValues() { - when(missingRef.valueOf(env)).thenReturn(missingValue()); - assertEquals(missingValue(), eval(makedate(missingRef, DSL.literal(42.)))); assertEquals(missingValue(), eval(makedate(DSL.literal(42.), missingRef))); assertEquals(missingValue(), eval(makedate(missingRef, missingRef))); diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java index 7dd78ff845..3fb2472c18 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java @@ -58,8 +58,6 @@ public void checkSecondFraction() { @Test public void checkNullValues() { - when(nullRef.valueOf(env)).thenReturn(nullValue()); - assertEquals(nullValue(), eval(maketime(nullRef, DSL.literal(42.), DSL.literal(42.)))); assertEquals(nullValue(), eval(maketime(DSL.literal(42.), nullRef, DSL.literal(42.)))); assertEquals(nullValue(), eval(maketime(DSL.literal(42.), DSL.literal(42.), nullRef))); @@ -71,8 +69,6 @@ public void checkNullValues() { @Test public void checkMissingValues() { - when(missingRef.valueOf(env)).thenReturn(missingValue()); - assertEquals(missingValue(), eval(maketime(missingRef, DSL.literal(42.), DSL.literal(42.)))); assertEquals(missingValue(), eval(maketime(DSL.literal(42.), missingRef, DSL.literal(42.)))); assertEquals(missingValue(), eval(maketime(DSL.literal(42.), DSL.literal(42.), missingRef))); diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java index e8f5c16025..0fa2994ecc 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java @@ -20,92 +20,109 @@ import java.time.Period; import java.time.temporal.Temporal; import java.util.List; +import java.util.concurrent.Callable; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestFactory; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.config.ExpressionConfig; -public class NowLikeFunctionTest extends ExpressionTestBase { - private static Stream functionNames() { - var dsl = new DSL(new ExpressionConfig().functionRepository()); - return Stream.of( - Arguments.of((Function)dsl::now, - "now", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function)dsl::current_timestamp, - "current_timestamp", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function)dsl::localtimestamp, - "localtimestamp", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function)dsl::localtime, - "localtime", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function)dsl::sysdate, - "sysdate", DATETIME, true, (Supplier)LocalDateTime::now), - Arguments.of((Function)dsl::curtime, - "curtime", TIME, false, (Supplier)LocalTime::now), - Arguments.of((Function)dsl::current_time, - "current_time", TIME, false, (Supplier)LocalTime::now), - Arguments.of((Function)dsl::curdate, - "curdate", DATE, false, (Supplier)LocalDate::now), - Arguments.of((Function)dsl::current_date, - "current_date", DATE, false, (Supplier)LocalDate::now)); +class NowLikeFunctionTest extends ExpressionTestBase { + @Test + void now() { + test_now_like_functions(dsl::now, + DATETIME, + false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); } - private Temporal extractValue(FunctionExpression func) { - switch ((ExprCoreType)func.type()) { - case DATE: return func.valueOf(null).dateValue(); - case DATETIME: return func.valueOf(null).datetimeValue(); - case TIME: return func.valueOf(null).timeValue(); - // unreachable code - default: throw new IllegalArgumentException(String.format("%s", func.type())); - } + @Test + void current_timestamp() { + test_now_like_functions(dsl::current_timestamp, DATETIME, false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); } - private long getDiff(Temporal sample, Temporal reference) { - if (sample instanceof LocalDate) { - return Period.between((LocalDate) sample, (LocalDate) reference).getDays(); - } - return Duration.between(sample, reference).toSeconds(); + @Test + void localtimestamp() { + test_now_like_functions(dsl::localtimestamp, DATETIME, false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void localtime() { + test_now_like_functions(dsl::localtime, DATETIME, false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void sysdate() { + test_now_like_functions(dsl::sysdate, DATETIME, true, LocalDateTime::now); + } + + @Test + void curtime() { + test_now_like_functions(dsl::curtime, TIME, false, + () -> LocalTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void currdate() { + + test_now_like_functions(dsl::curdate, + DATE, false, + () -> LocalDate.now(functionProperties.getQueryStartClock())); + } + + @Test + void current_time() { + test_now_like_functions(dsl::current_time, + TIME, + false, + () -> LocalTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void current_date() { + test_now_like_functions(dsl::current_date, DATE, false, + () -> LocalDate.now(functionProperties.getQueryStartClock())); } /** * Check how NOW-like functions are processed. - * @param function Function - * @param name Function name - * @param resType Return type - * @param hasFsp Whether function has fsp argument + * + * @param function Function + * @param resType Return type + * @param hasFsp Whether function has fsp argument * @param referenceGetter A callback to get reference value */ - @ParameterizedTest(name = "{1}") - @MethodSource("functionNames") - public void test_now_like_functions(Function function, - @SuppressWarnings("unused") // Used in the test name above - String name, - ExprCoreType resType, - Boolean hasFsp, - Supplier referenceGetter) { + void test_now_like_functions(Function function, + ExprCoreType resType, + Boolean hasFsp, + Supplier referenceGetter) { // Check return types: // `func()` - FunctionExpression expr = function.apply(new Expression[]{}); + FunctionExpression expr = function.apply(new Expression[] {}); assertEquals(resType, expr.type()); if (hasFsp) { // `func(fsp = 0)` - expr = function.apply(new Expression[]{DSL.literal(0)}); + expr = function.apply(new Expression[] {DSL.literal(0)}); assertEquals(resType, expr.type()); // `func(fsp = 6)` - expr = function.apply(new Expression[]{DSL.literal(6)}); + expr = function.apply(new Expression[] {DSL.literal(6)}); assertEquals(resType, expr.type()); - for (var wrongFspValue: List.of(-1, 10)) { + for (var wrongFspValue : List.of(-1, 10)) { var exception = assertThrows(IllegalArgumentException.class, - () -> function.apply(new Expression[]{DSL.literal(wrongFspValue)}).valueOf(null)); + () -> function.apply( + new Expression[] {DSL.literal(wrongFspValue)}).valueOf(null)); assertEquals(String.format("Invalid `fsp` value: %d, allowed 0 to 6", wrongFspValue), exception.getMessage()); } @@ -113,16 +130,67 @@ public void test_now_like_functions(Function f // Check how calculations are precise: // `func()` - assertTrue(Math.abs(getDiff( - extractValue(function.apply(new Expression[]{})), - referenceGetter.get() - )) <= 1); + Temporal sample = extractValue(function.apply(new Expression[] {})); + Temporal reference = referenceGetter.get(); + assertTrue(Math.abs(getDiff(reference, sample)) <= 1); if (hasFsp) { // `func(fsp)` assertTrue(Math.abs(getDiff( - extractValue(function.apply(new Expression[]{DSL.literal(0)})), - referenceGetter.get() + extractValue(function.apply(new Expression[] {DSL.literal(0)})), + referenceGetter.get() )) <= 1); } } + + @TestFactory + Stream constantValueTestFactory() { + BiFunction, DynamicTest> buildTest = (name, action) -> + DynamicTest.dynamicTest( + String.format("multiple_invocations_same_value_test[%s]", name), + () -> { + var v1 = extractValue(action.call()); + Thread.sleep(1000); + var v2 = extractValue(action.call()); + assertEquals(v1, v2); + } + ); + return Stream.of( + buildTest.apply("now", dsl::now), + buildTest.apply("current_timestamp", dsl::current_timestamp), + buildTest.apply("current_time", dsl::current_time), + buildTest.apply("curdate", dsl::curdate), + buildTest.apply("curtime", dsl::curtime), + buildTest.apply("localtimestamp", dsl::localtimestamp), + buildTest.apply("localtime", dsl::localtime) + ); + } + + @Test + void sysdate_multiple_invocations_differ() throws InterruptedException { + var v1 = extractValue(dsl.sysdate()); + Thread.sleep(1000); + var v2 = extractValue(dsl.sysdate()); + assertEquals(1, getDiff(v1, v2)); + } + + private Temporal extractValue(FunctionExpression func) { + switch ((ExprCoreType) func.type()) { + case DATE: + return func.valueOf(null).dateValue(); + case DATETIME: + return func.valueOf(null).datetimeValue(); + case TIME: + return func.valueOf(null).timeValue(); + // unreachable code + default: + throw new IllegalArgumentException(String.format("%s", func.type())); + } + } + + private long getDiff(Temporal sample, Temporal reference) { + if (sample instanceof LocalDate) { + return Period.between((LocalDate) sample, (LocalDate) reference).getDays(); + } + return Duration.between(sample, reference).toSeconds(); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java index 437e195f3e..dbc8468fdf 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java @@ -30,8 +30,10 @@ public class UnixTimeStampTest extends DateTimeTestBase { @Test public void checkNoArgs() { - assertEquals(System.currentTimeMillis() / 1000L, unixTimeStamp()); - assertEquals(System.currentTimeMillis() / 1000L, eval(unixTimeStampExpr()).longValue()); + + final long expected = functionProperties.getQueryStartClock().millis() / 1000L; + assertEquals(expected, unixTimeStamp()); + assertEquals(expected, eval(unixTimeStampExpr()).longValue()); } private static Stream getDateSamples() { diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java index dc509d175b..cb4d318eb3 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java @@ -8,38 +8,34 @@ import static org.junit.jupiter.api.Assertions.assertEquals; -import java.time.Instant; -import java.time.LocalDate; import java.time.LocalDateTime; import java.time.ZoneId; -import java.util.List; import java.util.stream.Stream; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; -import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDatetimeValue; import org.opensearch.sql.data.model.ExprDoubleValue; import org.opensearch.sql.data.model.ExprLongValue; -import org.opensearch.sql.data.model.ExprTimestampValue; -import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ExpressionTestBase; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.config.ExpressionConfig; -import org.opensearch.sql.expression.env.Environment; -import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionSignature; public class UnixTwoWayConversionTest extends DateTimeTestBase { @Test public void checkConvertNow() { - assertEquals(LocalDateTime.now(ZoneId.of("UTC")).withNano(0), fromUnixTime(unixTimeStamp())); - assertEquals(LocalDateTime.now(ZoneId.of("UTC")).withNano(0), - eval(fromUnixTime(unixTimeStampExpr())).datetimeValue()); + assertEquals(getExpectedNow(), fromUnixTime(unixTimeStamp())); + } + + @Test + public void checkConvertNow_with_eval() { + assertEquals(getExpectedNow(), eval(fromUnixTime(unixTimeStampExpr())).datetimeValue()); + } + + private LocalDateTime getExpectedNow() { + return LocalDateTime.now( + functionProperties.getQueryStartClock().withZone(ZoneId.of("UTC"))) + .withNano(0); } private static Stream getDoubleSamples() { diff --git a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java index 5dd98dfedf..3d0437f737 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertThrows; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; import static org.mockito.Mockito.lenient; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -58,6 +59,8 @@ class BuiltinFunctionRepositoryTest { @Mock private Map mockMap; @Mock + FunctionProperties functionProperties; + @Mock private FunctionName mockFunctionName; @Mock private FunctionBuilder functionExpressionBuilder; @@ -72,7 +75,7 @@ class BuiltinFunctionRepositoryTest { @BeforeEach void setUp() { - repo = new BuiltinFunctionRepository(mockNamespaceMap); + repo = new BuiltinFunctionRepository(mockNamespaceMap, functionProperties); } @Test @@ -80,7 +83,8 @@ void register() { when(mockNamespaceMap.get(DEFAULT_NAMESPACE)).thenReturn(mockMap); when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true); when(mockfunctionResolver.getFunctionName()).thenReturn(mockFunctionName); - BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); + BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap, + functionProperties); repo.register(mockfunctionResolver); verify(mockMap, times(1)).put(mockFunctionName, mockfunctionResolver); @@ -92,7 +96,8 @@ void register_under_catalog_namespace() { when(mockNamespaceMap.put(eq(TEST_NAMESPACE), any())).thenReturn(null); when(mockNamespaceMap.get(TEST_NAMESPACE)).thenReturn(mockMap); when(mockfunctionResolver.getFunctionName()).thenReturn(mockFunctionName); - BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); + BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap, + functionProperties); repo.register(TEST_NAMESPACE, mockfunctionResolver); verify(mockNamespaceMap, times(1)).put(eq(TEST_NAMESPACE), any()); @@ -111,11 +116,12 @@ void compile() { when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true); when(mockMap.containsKey(mockFunctionName)).thenReturn(true); when(mockMap.get(mockFunctionName)).thenReturn(mockfunctionResolver); - BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); + BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap, + functionProperties); repo.register(mockfunctionResolver); repo.compile(mockFunctionName, Arrays.asList(mockExpression)); - verify(functionExpressionBuilder, times(1)).apply(any()); + verify(functionExpressionBuilder, times(1)).apply(same(functionProperties), any()); } @@ -130,11 +136,12 @@ void compile_function_under_catalog_namespace() { when(mockNamespaceMap.containsKey(TEST_NAMESPACE)).thenReturn(true); when(mockMap.containsKey(mockFunctionName)).thenReturn(true); when(mockMap.get(mockFunctionName)).thenReturn(mockfunctionResolver); - BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); + BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap, + functionProperties); repo.register(TEST_NAMESPACE, mockfunctionResolver); repo.compile(TEST_NAMESPACE, mockFunctionName, Arrays.asList(mockExpression)); - verify(functionExpressionBuilder, times(1)).apply(any()); + verify(functionExpressionBuilder, times(1)).apply(same(functionProperties), any()); } @Test @@ -148,7 +155,8 @@ void resolve() { when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true); when(mockMap.containsKey(mockFunctionName)).thenReturn(true); when(mockMap.get(mockFunctionName)).thenReturn(mockfunctionResolver); - BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); + BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap, + functionProperties); repo.register(mockfunctionResolver); assertEquals(functionExpressionBuilder, @@ -161,7 +169,7 @@ void resolve_should_not_cast_arguments_in_cast_function() { FunctionImplementation function = repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(CAST_TO_BOOLEAN.getName(), DATETIME, BOOLEAN)) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("cast_to_boolean(string)", function.toString()); } @@ -172,7 +180,7 @@ void resolve_should_not_cast_arguments_if_same_type() { FunctionImplementation function = repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(mockFunctionName, STRING, STRING)) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("mock(string)", function.toString()); } @@ -183,7 +191,7 @@ void resolve_should_not_cast_arguments_if_both_numbers() { FunctionImplementation function = repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(mockFunctionName, BYTE, INTEGER)) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("mock(byte)", function.toString()); } @@ -199,7 +207,7 @@ void resolve_should_cast_arguments() { FunctionImplementation function = repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), signature) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("mock(cast_to_boolean(string))", function.toString()); } @@ -209,7 +217,7 @@ void resolve_should_throw_exception_for_unsupported_conversion() { assertThrows(ExpressionEvaluationException.class, () -> repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(mockFunctionName, BYTE, STRUCT)) - .apply(ImmutableList.of(mockExpression))); + .apply(functionProperties, ImmutableList.of(mockExpression))); assertEquals(error.getMessage(), "Type conversion to type STRUCT is not supported"); } @@ -219,7 +227,8 @@ void resolve_unregistered() { when(mockNamespaceMap.get(DEFAULT_NAMESPACE)).thenReturn(mockMap); when(mockNamespaceMap.containsKey(DEFAULT_NAMESPACE)).thenReturn(true); when(mockMap.containsKey(any())).thenReturn(false); - BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); + BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap, + functionProperties); repo.register(mockfunctionResolver); ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, @@ -249,8 +258,8 @@ private FunctionSignature registerFunctionResolver(FunctionName funcName, // Relax unnecessary stubbing check because error case test doesn't call this lenient().doAnswer(invocation -> - new FakeFunctionExpression(funcName, invocation.getArgument(0)) - ).when(funcBuilder).apply(any()); + new FakeFunctionExpression(funcName, invocation.getArgument(1)) + ).when(funcBuilder).apply(same(functionProperties), any()); return unresolvedSignature; } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java new file mode 100644 index 0000000000..8bf4d7ba24 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.expression.function.FunctionDSL.define; + +import java.util.List; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.expression.Expression; + +@ExtendWith(MockitoExtension.class) +class FunctionDSLDefineTest extends FunctionDSLTestBase { + + @Test + void define_variableArgs_test() { + Pair resolvedA = + Pair.of(SAMPLE_SIGNATURE_A, new SampleFunctionBuilder()); + + FunctionResolver resolver = define(SAMPLE_NAME, v -> resolvedA); + + assertEquals(resolvedA, resolver.resolve(SAMPLE_SIGNATURE_A)); + } + + @Test + void define_test() { + Pair resolved = + Pair.of(SAMPLE_SIGNATURE_A, new SampleFunctionBuilder()); + + FunctionResolver resolver = define(SAMPLE_NAME, List.of(v -> resolved)); + + assertEquals(resolved, resolver.resolve(SAMPLE_SIGNATURE_A)); + } + + @Test + void define_name_test() { + Pair resolved = + Pair.of(SAMPLE_SIGNATURE_A, new SampleFunctionBuilder()); + + FunctionResolver resolver = define(SAMPLE_NAME, List.of(v -> resolved)); + + assertEquals(SAMPLE_NAME, resolver.getFunctionName()); + } + + static class SampleFunctionBuilder implements FunctionBuilder { + @Override + public FunctionImplementation apply(FunctionProperties functionProperties, + List arguments) { + return new SampleFunctionImplementation(arguments); + } + } + + @RequiredArgsConstructor + static class SampleFunctionImplementation implements FunctionImplementation { + @Getter + private final List arguments; + + @Override + public FunctionName getFunctionName() { + return SAMPLE_NAME; + } + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java new file mode 100644 index 0000000000..193066e626 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.List; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprMissingValue; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; + +@ExtendWith(MockitoExtension.class) +public class FunctionDSLTestBase { + @Mock + FunctionProperties functionProperties; + + public static final ExprNullValue NULL = ExprNullValue.of(); + public static final ExprMissingValue MISSING = ExprMissingValue.of(); + protected static final ExprType ANY_TYPE = () -> "ANY"; + protected static final ExprValue ANY = new ExprValue() { + @Override + public Object value() { + throw new RuntimeException(); + } + + @Override + public ExprType type() { + return ANY_TYPE; + } + + @Override + public String toString() { + return "ANY"; + } + + @Override + public int compareTo(ExprValue o) { + throw new RuntimeException(); + } + }; + static final FunctionName SAMPLE_NAME = FunctionName.of("sample"); + static final FunctionSignature SAMPLE_SIGNATURE_A = + new FunctionSignature(SAMPLE_NAME, List.of(ExprCoreType.UNDEFINED)); + static final SerializableNoArgFunction noArg = () -> ANY; + static final SerializableFunction oneArg = v -> ANY; + static final SerializableBiFunction + oneArgWithProperties = (functionProperties, v) -> ANY; + + static final SerializableBiFunction + twoArgs = (v1, v2) -> ANY; + static final SerializableTriFunction + threeArgs = (v1, v2, v3) -> ANY; + @Mock + FunctionProperties mockProperties; +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java new file mode 100644 index 0000000000..5d970803ed --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplNoArgTest extends FunctionDSLimplTestBase { + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(noArg, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(); + } + + @Override + String getExpected_toString() { + return "sample()"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java new file mode 100644 index 0000000000..6e7c194f5d --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplOneArgTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(oneArg, ANY_TYPE, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java new file mode 100644 index 0000000000..f0eca763b1 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; + +abstract class FunctionDSLimplTestBase extends FunctionDSLTestBase { + @Test + void implementationGenerator_is_valid() { + assertNotNull(getImplementationGenerator()); + } + + @Test + void implementation_is_valid_pair() { + assertNotNull(getImplementation().getKey()); + assertNotNull(getImplementation().getValue()); + } + + @Test + void implementation_expected_functionName() { + assertEquals(SAMPLE_NAME, getImplementation().getKey().getFunctionName()); + } + + @Test + void implementation_valid_functionBuilder() { + + FunctionBuilder v = getImplementation().getValue(); + assertDoesNotThrow(() -> v.apply(mockProperties, getSampleArguments())); + } + + @Test + void implementation_functionBuilder_return_functionExpression() { + FunctionImplementation executable = getImplementation().getValue() + .apply(mockProperties, getSampleArguments()); + assertTrue(executable instanceof FunctionExpression); + } + + @Test + void implementation_functionExpression_valueOf() { + FunctionExpression executable = + (FunctionExpression) getImplementation().getValue().apply(mockProperties, + getSampleArguments()); + + assertEquals(ANY, executable.valueOf(null)); + } + + @Test + void implementation_functionExpression_type() { + FunctionExpression executable = + (FunctionExpression) getImplementation().getValue().apply(mockProperties, + getSampleArguments()); + assertEquals(ANY_TYPE, executable.type()); + } + + @Test + void implementation_functionExpression_toString() { + FunctionExpression executable = + (FunctionExpression) getImplementation().getValue().apply(mockProperties, + getSampleArguments()); + assertEquals(getExpected_toString(), executable.toString()); + } + + /** + * A lambda that takes a function name and returns an implementation + * of the function. + */ + abstract SerializableFunction> + getImplementationGenerator(); + + Pair getImplementation() { + return getImplementationGenerator().apply(SAMPLE_NAME); + } + + abstract List getSampleArguments(); + + abstract String getExpected_toString(); +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java new file mode 100644 index 0000000000..eab8d24c59 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplThreeArgTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(threeArgs, ANY_TYPE, ANY_TYPE, ANY_TYPE, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY), DSL.literal(ANY), DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY, ANY, ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java new file mode 100644 index 0000000000..87d097c9eb --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplTwoArgTest extends FunctionDSLimplTestBase { + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(twoArgs, ANY_TYPE, ANY_TYPE, ANY_TYPE); + } + + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY), DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY, ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java new file mode 100644 index 0000000000..c3c41b6c0c --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplWithPropertiesNoArgsTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return FunctionDSL.implWithProperties(fp -> ANY, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(); + } + + @Override + String getExpected_toString() { + return "sample()"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java new file mode 100644 index 0000000000..4a05326c0a --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplWithPropertiesOneArgTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + SerializableBiFunction functionBody + = (fp, arg) -> ANY; + return FunctionDSL.implWithProperties(functionBody, ANY_TYPE, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java new file mode 100644 index 0000000000..706a5bba16 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java @@ -0,0 +1,109 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; +import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandlingWithProperties; + +import org.junit.jupiter.api.Test; + +class FunctionDSLnullMissingHandlingTest extends FunctionDSLTestBase { + + @Test + void nullMissingHandling_oneArg_nullValue() { + assertEquals(NULL, nullMissingHandling(oneArg).apply(NULL)); + } + + @Test + void nullMissingHandling_oneArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(oneArg).apply(MISSING)); + } + + @Test + void nullMissingHandling_oneArg_apply() { + assertEquals(ANY, nullMissingHandling(oneArg).apply(ANY)); + } + + + @Test + void nullMissingHandling_oneArg_FunctionProperties_nullValue() { + assertEquals(NULL, nullMissingHandlingWithProperties(oneArgWithProperties).apply(functionProperties, NULL)); + } + + @Test + void nullMissingHandling_oneArg_FunctionProperties_missingValue() { + assertEquals(MISSING, nullMissingHandlingWithProperties(oneArgWithProperties).apply(functionProperties, MISSING)); + } + + @Test + void nullMissingHandling_oneArg_FunctionProperties_apply() { + assertEquals(ANY, nullMissingHandlingWithProperties(oneArgWithProperties).apply(functionProperties, ANY)); + } + + @Test + void nullMissingHandling_twoArgs_firstArg_nullValue() { + assertEquals(NULL, nullMissingHandling(twoArgs).apply(NULL, ANY)); + } + + @Test + void nullMissingHandling_twoArgs_secondArg_nullValue() { + assertEquals(NULL, nullMissingHandling(twoArgs).apply(ANY, NULL)); + } + + + @Test + void nullMissingHandling_twoArgs_firstArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(twoArgs).apply(MISSING, ANY)); + } + + @Test + void nullMissingHandling_twoArgs_secondArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(twoArgs).apply(ANY, MISSING)); + } + + @Test + void nullMissingHandling_twoArgs_apply() { + assertEquals(ANY, nullMissingHandling(twoArgs).apply(ANY, ANY)); + } + + + @Test + void nullMissingHandling_threeArgs_firstArg_nullValue() { + assertEquals(NULL, nullMissingHandling(threeArgs).apply(NULL, ANY, ANY)); + } + + @Test + void nullMissingHandling_threeArgs_secondArg_nullValue() { + assertEquals(NULL, nullMissingHandling(threeArgs).apply(ANY, NULL, ANY)); + } + + @Test + void nullMissingHandling_threeArgs_thirdArg_nullValue() { + assertEquals(NULL, nullMissingHandling(threeArgs).apply(ANY, ANY, NULL)); + } + + + @Test + void nullMissingHandling_threeArgs_firstArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(threeArgs).apply(MISSING, ANY, ANY)); + } + + @Test + void nullMissingHandling_threeArg_secondArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(threeArgs).apply(ANY, MISSING, ANY)); + } + + @Test + void nullMissingHandling_threeArg_thirdArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(threeArgs).apply(ANY, ANY, MISSING)); + } + + @Test + void nullMissingHandling_threeArg_apply() { + assertEquals(ANY, nullMissingHandling(threeArgs).apply(ANY, ANY, ANY)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java new file mode 100644 index 0000000000..3f51a05b65 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java @@ -0,0 +1,69 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class FunctionPropertiesTest { + + FunctionProperties functionProperties; + Instant startTime; + + @BeforeEach + void init() { + startTime = Instant.now(); + functionProperties = new FunctionProperties(startTime, ZoneId.systemDefault()); + } + + @Test + void getQueryStartClock_returns_constructor_instant() { + assertEquals(startTime, functionProperties.getQueryStartClock().instant()); + } + + @Test + void getQueryStartClock_differs_from_instantNow() throws InterruptedException { + // Give system clock a chance to advance. + Thread.sleep(1000); + assertNotEquals(Instant.now(), functionProperties.getQueryStartClock().instant()); + } + + @Test + void getSystemClock_is_systemClock() { + assertEquals(Clock.systemDefaultZone(), functionProperties.getSystemClock()); + } + + @Test + void functionProperties_can_be_serialized() throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(functionProperties); + objectOutput.flush(); + assertNotEquals(0, output.size()); + } + + @Test + void functionProperties_can_be_deserialized() throws IOException, ClassNotFoundException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(functionProperties); + objectOutput.flush(); + ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray()); + ObjectInputStream objectInput = new ObjectInputStream(input); + assertEquals(functionProperties, objectInput.readObject()); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java b/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java index f791b7d86a..fcdc98ba22 100644 --- a/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/operator/convert/TypeCastOperatorTest.java @@ -23,8 +23,10 @@ import java.util.stream.Stream; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.analysis.AnalyzerTestBase; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprByteValue; import org.opensearch.sql.data.model.ExprDateValue; @@ -41,10 +43,16 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ExpressionConfig.class, AnalyzerTestBase.class}) class TypeCastOperatorTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + @Autowired + private DSL dsl; private static Stream numberData() { return Stream.of(new ExprByteValue(3), new ExprShortValue(3), diff --git a/core/src/test/java/org/opensearch/sql/expression/system/SystemFunctionsTest.java b/core/src/test/java/org/opensearch/sql/expression/system/SystemFunctionsTest.java index 453018a700..0bf68b3780 100644 --- a/core/src/test/java/org/opensearch/sql/expression/system/SystemFunctionsTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/system/SystemFunctionsTest.java @@ -16,6 +16,7 @@ import java.util.LinkedHashMap; import java.util.List; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.data.model.AbstractExprValue; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.model.ExprByteValue; @@ -39,9 +40,18 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ExpressionConfig.class}) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) public class SystemFunctionsTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + + @Autowired + DSL dsl; @Test void typeof() { diff --git a/opensearch/build.gradle b/opensearch/build.gradle index 7ad7d63546..ef770353ec 100644 --- a/opensearch/build.gradle +++ b/opensearch/build.gradle @@ -45,6 +45,7 @@ dependencies { testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4' testImplementation group: 'org.opensearch.client', name: 'opensearch-rest-high-level-client', version: "${opensearch_version}" testImplementation group: 'org.opensearch.test', name: 'framework', version: "${opensearch_version}" + testImplementation "org.springframework:spring-test:${spring_version}" } test { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/OpenSearchTestBase.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/OpenSearchTestBase.java new file mode 100644 index 0000000000..6c93ed77a1 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/OpenSearchTestBase.java @@ -0,0 +1,23 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch; + +import org.junit.jupiter.api.extension.ExtendWith; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.config.ExpressionConfig; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; + +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ExpressionConfig.class}) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) +public class OpenSearchTestBase { + + @Autowired + protected DSL dsl; +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataTypeRecognitionTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataTypeRecognitionTest.java index 48121baad2..38e522bc18 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataTypeRecognitionTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/data/type/OpenSearchDataTypeRecognitionTest.java @@ -13,16 +13,14 @@ import org.junit.jupiter.params.provider.MethodSource; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.data.value.OpenSearchExprBinaryValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprGeoPointValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprIpValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprTextKeywordValue; import org.opensearch.sql.opensearch.data.value.OpenSearchExprTextValue; -public class OpenSearchDataTypeRecognitionTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +public class OpenSearchDataTypeRecognitionTest extends OpenSearchTestBase { @ParameterizedTest @MethodSource("types") diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java index df42a2b201..42cf4bdde3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java @@ -33,16 +33,14 @@ import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.utils.Utils; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) -class OpenSearchLogicOptimizerTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class OpenSearchLogicOptimizerTest extends OpenSearchTestBase { @Mock private Table table; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 82ac3991ac..d28af7cb2b 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -51,13 +51,12 @@ import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.AvgAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; @@ -73,9 +72,7 @@ import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) -class OpenSearchIndexTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class OpenSearchIndexTest extends OpenSearchTestBase { @Mock private OpenSearchClient client; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java index 3614d82e59..de7eca84d4 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java @@ -53,13 +53,12 @@ import org.opensearch.sql.expression.aggregation.CountAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class AggregationQueryBuilderTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class AggregationQueryBuilderTest extends OpenSearchTestBase { @Mock private ExpressionSerializer serializer; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java index e721b7ed27..08acde104b 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/ExpressionAggregationScriptTest.java @@ -35,12 +35,11 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class ExpressionAggregationScriptTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class ExpressionAggregationScriptTest extends OpenSearchTestBase { @Mock private SearchLookup lookup; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index b2ad41d516..29854abc88 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -44,12 +44,12 @@ import org.opensearch.sql.expression.aggregation.TakeAggregator; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class MetricAggregationBuilderTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class MetricAggregationBuilderTest extends OpenSearchTestBase { @Mock private ExpressionSerializer serializer; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/ExpressionFilterScriptTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/ExpressionFilterScriptTest.java index c3965d8408..7d34adcaa8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/ExpressionFilterScriptTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/ExpressionFilterScriptTest.java @@ -45,12 +45,11 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class ExpressionFilterScriptTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class ExpressionFilterScriptTest extends OpenSearchTestBase { @Mock private SearchLookup lookup; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index 2b54e512a4..1a997731e7 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -54,13 +54,12 @@ import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) -class FilterQueryBuilderTest { - - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class FilterQueryBuilderTest extends OpenSearchTestBase { private static Stream numericCastSource() { return Stream.of(literal((byte) 1), literal((short) -1), literal( @@ -314,8 +313,8 @@ void should_build_match_query_with_default_parameters() { + "}", buildQuery( dsl.match( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -342,28 +341,28 @@ void should_build_match_query_with_custom_parameters() { + "}", buildQuery( dsl.match( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("operator", literal("AND")), - dsl.namedArgument("analyzer", literal("keyword")), - dsl.namedArgument("auto_generate_synonyms_phrase_query", literal("true")), - dsl.namedArgument("fuzziness", literal("AUTO")), - dsl.namedArgument("max_expansions", literal("50")), - dsl.namedArgument("prefix_length", literal("0")), - dsl.namedArgument("fuzzy_transpositions", literal("false")), - dsl.namedArgument("fuzzy_rewrite", literal("top_terms_1")), - dsl.namedArgument("lenient", literal("false")), - dsl.namedArgument("minimum_should_match", literal("3")), - dsl.namedArgument("zero_terms_query", literal("ALL")), - dsl.namedArgument("boost", literal("2.0"))))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("operator", literal("AND")), + DSL.namedArgument("analyzer", literal("keyword")), + DSL.namedArgument("auto_generate_synonyms_phrase_query", literal("true")), + DSL.namedArgument("fuzziness", literal("AUTO")), + DSL.namedArgument("max_expansions", literal("50")), + DSL.namedArgument("prefix_length", literal("0")), + DSL.namedArgument("fuzzy_transpositions", literal("false")), + DSL.namedArgument("fuzzy_rewrite", literal("top_terms_1")), + DSL.namedArgument("lenient", literal("false")), + DSL.namedArgument("minimum_should_match", literal("3")), + DSL.namedArgument("zero_terms_query", literal("ALL")), + DSL.namedArgument("boost", literal("2.0"))))); } @Test void match_invalid_parameter() { FunctionExpression expr = dsl.match( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("invalid_parameter", literal("invalid_value"))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("invalid_parameter", literal("invalid_value"))); var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); assertTrue(msg.startsWith("Parameter invalid_parameter is invalid for match function.")); } @@ -371,10 +370,10 @@ void match_invalid_parameter() { @Test void match_disallow_duplicate_parameter() { FunctionExpression expr = dsl.match( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("analyzer", literal("keyword")), - dsl.namedArgument("AnalYzer", literal("english"))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("analyzer", literal("keyword")), + DSL.namedArgument("AnalYzer", literal("english"))); var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); assertEquals("Parameter 'analyzer' can only be specified once.", msg); } @@ -434,8 +433,8 @@ void should_build_match_phrase_query_with_default_parameters() { + "}", buildQuery( dsl.match_phrase( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -458,10 +457,10 @@ void should_build_multi_match_query_with_default_parameters_single_field() { + " }\n" + "}", buildQuery(dsl.multi_match( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F)))))), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -484,10 +483,10 @@ void should_build_multi_match_query_with_default_parameters_all_fields() { + " }\n" + "}", buildQuery(dsl.multi_match( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "*", ExprValueUtils.floatValue(1.F)))))), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -508,9 +507,9 @@ void should_build_multi_match_query_with_default_parameters_no_fields() { + " }\n" + "}", buildQuery(dsl.multi_match( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of())))), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("query", literal("search query"))))); } // Note: we can't test `multi_match` and `simple_query_string` without weight(s) @@ -533,11 +532,11 @@ void should_build_multi_match_query_with_default_parameters_multiple_fields() { + " }\n" + "}"; var actual = buildQuery(dsl.multi_match( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), "field2", ExprValueUtils.floatValue(.3F)))))), - dsl.namedArgument("query", literal("search query")))); + DSL.namedArgument("query", literal("search query")))); var ex1 = String.format(expected, "\"field1^1.0\", \"field2^0.3\""); var ex2 = String.format(expected, "\"field2^0.3\", \"field1^1.0\""); @@ -571,24 +570,24 @@ void should_build_multi_match_query_with_custom_parameters() { + "}"; var actual = buildQuery( dsl.multi_match( - dsl.namedArgument("fields", DSL.literal( + DSL.namedArgument("fields", DSL.literal( ExprValueUtils.tupleValue(ImmutableMap.of("field1", 1.F, "field2", .3F)))), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("analyzer", literal("keyword")), - dsl.namedArgument("auto_generate_synonyms_phrase_query", literal("false")), - dsl.namedArgument("cutoff_frequency", literal("4.3")), - dsl.namedArgument("fuzziness", literal("AUTO:2,4")), - dsl.namedArgument("fuzzy_transpositions", literal("false")), - dsl.namedArgument("lenient", literal("false")), - dsl.namedArgument("max_expansions", literal("3")), - dsl.namedArgument("minimum_should_match", literal("3")), - dsl.namedArgument("operator", literal("AND")), - dsl.namedArgument("prefix_length", literal("1")), - dsl.namedArgument("slop", literal("1")), - dsl.namedArgument("tie_breaker", literal("1")), - dsl.namedArgument("type", literal("phrase_prefix")), - dsl.namedArgument("zero_terms_query", literal("ALL")), - dsl.namedArgument("boost", literal("2.0")))); + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("analyzer", literal("keyword")), + DSL.namedArgument("auto_generate_synonyms_phrase_query", literal("false")), + DSL.namedArgument("cutoff_frequency", literal("4.3")), + DSL.namedArgument("fuzziness", literal("AUTO:2,4")), + DSL.namedArgument("fuzzy_transpositions", literal("false")), + DSL.namedArgument("lenient", literal("false")), + DSL.namedArgument("max_expansions", literal("3")), + DSL.namedArgument("minimum_should_match", literal("3")), + DSL.namedArgument("operator", literal("AND")), + DSL.namedArgument("prefix_length", literal("1")), + DSL.namedArgument("slop", literal("1")), + DSL.namedArgument("tie_breaker", literal("1")), + DSL.namedArgument("type", literal("phrase_prefix")), + DSL.namedArgument("zero_terms_query", literal("ALL")), + DSL.namedArgument("boost", literal("2.0")))); var ex1 = String.format(expected, "\"field1^1.0\", \"field2^0.3\""); var ex2 = String.format(expected, "\"field2^0.3\", \"field1^1.0\""); @@ -600,12 +599,12 @@ void should_build_multi_match_query_with_custom_parameters() { @Test void multi_match_invalid_parameter() { FunctionExpression expr = dsl.multi_match( - dsl.namedArgument("fields", DSL.literal( + DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), "field2", ExprValueUtils.floatValue(.3F)))))), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("invalid_parameter", literal("invalid_value"))); + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("invalid_parameter", literal("invalid_value"))); assertThrows(SemanticCheckException.class, () -> buildQuery(expr), "Parameter invalid_parameter is invalid for match function."); } @@ -626,12 +625,12 @@ void should_build_match_phrase_query_with_custom_parameters() { + "}", buildQuery( dsl.match_phrase( - dsl.namedArgument("boost", literal("1.2")), - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("analyzer", literal("keyword")), - dsl.namedArgument("slop", literal("2")), - dsl.namedArgument("zero_terms_query", literal("ALL"))))); + DSL.namedArgument("boost", literal("1.2")), + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("analyzer", literal("keyword")), + DSL.namedArgument("slop", literal("2")), + DSL.namedArgument("zero_terms_query", literal("ALL"))))); } @Test @@ -724,12 +723,12 @@ void should_build_query_query_with_custom_parameters() { @Test void query_string_invalid_parameter() { FunctionExpression expr = dsl.query_string( - dsl.namedArgument("fields", DSL.literal( + DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), "field2", ExprValueUtils.floatValue(.3F)))))), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("invalid_parameter", literal("invalid_value"))); + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("invalid_parameter", literal("invalid_value"))); assertThrows(SemanticCheckException.class, () -> buildQuery(expr), "Parameter invalid_parameter is invalid for match function."); } @@ -755,11 +754,11 @@ void should_build_query_string_query_with_default_parameters_multiple_fields() { + " }\n" + "}"; var actual = buildQuery(dsl.query_string( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), "field2", ExprValueUtils.floatValue(.3F)))))), - dsl.namedArgument("query", literal("query_value")))); + DSL.namedArgument("query", literal("query_value")))); var ex1 = String.format(expected, "\"field1^1.0\", \"field2^0.3\""); var ex2 = String.format(expected, "\"field2^0.3\", \"field1^1.0\""); @@ -795,21 +794,21 @@ void should_build_query_string_query_with_custom_parameters() { + "}"; var actual = buildQuery( dsl.query_string( - dsl.namedArgument("fields", DSL.literal( + DSL.namedArgument("fields", DSL.literal( ExprValueUtils.tupleValue(ImmutableMap.of("field1", 1.F, "field2", .3F)))), - dsl.namedArgument("query", literal("query_value")), - dsl.namedArgument("analyze_wildcard", literal("true")), - dsl.namedArgument("analyzer", literal("keyword")), - dsl.namedArgument("auto_generate_synonyms_phrase_query", literal("false")), - dsl.namedArgument("default_operator", literal("AND")), - dsl.namedArgument("fuzzy_max_expansions", literal("10")), - dsl.namedArgument("fuzzy_prefix_length", literal("2")), - dsl.namedArgument("fuzzy_transpositions", literal("false")), - dsl.namedArgument("lenient", literal("false")), - dsl.namedArgument("minimum_should_match", literal("3")), - dsl.namedArgument("tie_breaker", literal("1.3")), - dsl.namedArgument("type", literal("cross_fields")), - dsl.namedArgument("boost", literal("2.0")))); + DSL.namedArgument("query", literal("query_value")), + DSL.namedArgument("analyze_wildcard", literal("true")), + DSL.namedArgument("analyzer", literal("keyword")), + DSL.namedArgument("auto_generate_synonyms_phrase_query", literal("false")), + DSL.namedArgument("default_operator", literal("AND")), + DSL.namedArgument("fuzzy_max_expansions", literal("10")), + DSL.namedArgument("fuzzy_prefix_length", literal("2")), + DSL.namedArgument("fuzzy_transpositions", literal("false")), + DSL.namedArgument("lenient", literal("false")), + DSL.namedArgument("minimum_should_match", literal("3")), + DSL.namedArgument("tie_breaker", literal("1.3")), + DSL.namedArgument("type", literal("cross_fields")), + DSL.namedArgument("boost", literal("2.0")))); var ex1 = String.format(expected, "\"field1^1.0\", \"field2^0.3\""); var ex2 = String.format(expected, "\"field2^0.3\", \"field1^1.0\""); @@ -841,10 +840,10 @@ void should_build_query_string_query_with_default_parameters_single_field() { + " }\n" + "}", buildQuery(dsl.query_string( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F)))))), - dsl.namedArgument("query", literal("query_value"))))); + DSL.namedArgument("query", literal("query_value"))))); } @Test @@ -870,10 +869,10 @@ void should_build_simple_query_string_query_with_default_parameters_single_field + " }\n" + "}", buildQuery(dsl.simple_query_string( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F)))))), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -893,11 +892,11 @@ void should_build_simple_query_string_query_with_default_parameters_multiple_fie + " }\n" + "}"; var actual = buildQuery(dsl.simple_query_string( - dsl.namedArgument("fields", DSL.literal(new ExprTupleValue( + DSL.namedArgument("fields", DSL.literal(new ExprTupleValue( new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), "field2", ExprValueUtils.floatValue(.3F)))))), - dsl.namedArgument("query", literal("search query")))); + DSL.namedArgument("query", literal("search query")))); var ex1 = String.format(expected, "\"field1^1.0\", \"field2^0.3\""); var ex2 = String.format(expected, "\"field2^0.3\", \"field1^1.0\""); @@ -927,20 +926,20 @@ void should_build_simple_query_string_query_with_custom_parameters() { + "}"; var actual = buildQuery( dsl.simple_query_string( - dsl.namedArgument("fields", DSL.literal( + DSL.namedArgument("fields", DSL.literal( ExprValueUtils.tupleValue(ImmutableMap.of("field1", 1.F, "field2", .3F)))), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("analyze_wildcard", literal("true")), - dsl.namedArgument("analyzer", literal("keyword")), - dsl.namedArgument("auto_generate_synonyms_phrase_query", literal("false")), - dsl.namedArgument("default_operator", literal("AND")), - dsl.namedArgument("flags", literal("AND")), - dsl.namedArgument("fuzzy_max_expansions", literal("10")), - dsl.namedArgument("fuzzy_prefix_length", literal("2")), - dsl.namedArgument("fuzzy_transpositions", literal("false")), - dsl.namedArgument("lenient", literal("false")), - dsl.namedArgument("minimum_should_match", literal("3")), - dsl.namedArgument("boost", literal("2.0")))); + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("analyze_wildcard", literal("true")), + DSL.namedArgument("analyzer", literal("keyword")), + DSL.namedArgument("auto_generate_synonyms_phrase_query", literal("false")), + DSL.namedArgument("default_operator", literal("AND")), + DSL.namedArgument("flags", literal("AND")), + DSL.namedArgument("fuzzy_max_expansions", literal("10")), + DSL.namedArgument("fuzzy_prefix_length", literal("2")), + DSL.namedArgument("fuzzy_transpositions", literal("false")), + DSL.namedArgument("lenient", literal("false")), + DSL.namedArgument("minimum_should_match", literal("3")), + DSL.namedArgument("boost", literal("2.0")))); var ex1 = String.format(expected, "\"field1^1.0\", \"field2^0.3\""); var ex2 = String.format(expected, "\"field2^0.3\", \"field1^1.0\""); @@ -952,12 +951,12 @@ void should_build_simple_query_string_query_with_custom_parameters() { @Test void simple_query_string_invalid_parameter() { FunctionExpression expr = dsl.simple_query_string( - dsl.namedArgument("fields", DSL.literal( + DSL.namedArgument("fields", DSL.literal( new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( "field1", ExprValueUtils.floatValue(1.F), "field2", ExprValueUtils.floatValue(.3F)))))), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("invalid_parameter", literal("invalid_value"))); + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("invalid_parameter", literal("invalid_value"))); assertThrows(SemanticCheckException.class, () -> buildQuery(expr), "Parameter invalid_parameter is invalid for match function."); } @@ -965,9 +964,9 @@ void simple_query_string_invalid_parameter() { @Test void match_phrase_invalid_parameter() { FunctionExpression expr = dsl.match_phrase( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("invalid_parameter", literal("invalid_value"))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("invalid_parameter", literal("invalid_value"))); var msg = assertThrows(SemanticCheckException.class, () -> buildQuery(expr)).getMessage(); assertTrue(msg.startsWith("Parameter invalid_parameter is invalid for match_phrase function.")); } @@ -1052,8 +1051,8 @@ void should_build_match_bool_prefix_query_with_default_parameters() { + "}", buildQuery( dsl.match_bool_prefix( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -1104,8 +1103,8 @@ void should_build_match_phrase_prefix_query_with_default_parameters() { + "}", buildQuery( dsl.match_phrase_prefix( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query"))))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query"))))); } @Test @@ -1125,11 +1124,11 @@ void should_build_match_phrase_prefix_query_with_non_default_parameters() { + "}", buildQuery( dsl.match_phrase_prefix( - dsl.namedArgument("field", literal("message")), - dsl.namedArgument("query", literal("search query")), - dsl.namedArgument("boost", literal("1.2")), - dsl.namedArgument("max_expansions", literal("42")), - dsl.namedArgument("analyzer", literal("english"))))); + DSL.namedArgument("field", literal("message")), + DSL.namedArgument("query", literal("search query")), + DSL.namedArgument("boost", literal("1.2")), + DSL.namedArgument("max_expansions", literal("42")), + DSL.namedArgument("analyzer", literal("english"))))); } @Test diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQueryTest.java index ace10a019f..4dc91d2793 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LuceneQueryTest.java @@ -15,13 +15,13 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -class LuceneQueryTest { +class LuceneQueryTest extends OpenSearchTestBase { @Test void should_not_support_single_argument_by_default() { - DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); assertFalse(new LuceneQuery(){}.canSupport(dsl.abs(DSL.ref("age", INTEGER)))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java index c30e06bc1a..9a7e599acf 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchBoolPrefixQueryTest.java @@ -26,29 +26,28 @@ import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchBoolPrefixQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class MatchBoolPrefixQueryTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +public class MatchBoolPrefixQueryTest extends OpenSearchTestBase { private final MatchBoolPrefixQuery matchBoolPrefixQuery = new MatchBoolPrefixQuery(); private final FunctionName matchBoolPrefix = FunctionName.of("match_bool_prefix"); static Stream> generateValidData() { - final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); - NamedArgumentExpression field = dsl.namedArgument("field", DSL.literal("field_value")); - NamedArgumentExpression query = dsl.namedArgument("query", DSL.literal("query_value")); + NamedArgumentExpression field = DSL.namedArgument("field", DSL.literal("field_value")); + NamedArgumentExpression query = DSL.namedArgument("query", DSL.literal("query_value")); return List.of( - dsl.namedArgument("fuzziness", DSL.literal("AUTO")), - dsl.namedArgument("max_expansions", DSL.literal("50")), - dsl.namedArgument("prefix_length", DSL.literal("0")), - dsl.namedArgument("fuzzy_transpositions", DSL.literal("true")), - dsl.namedArgument("fuzzy_rewrite", DSL.literal("constant_score")), - dsl.namedArgument("minimum_should_match", DSL.literal("3")), - dsl.namedArgument("boost", DSL.literal("1")), - dsl.namedArgument("analyzer", DSL.literal("simple")), - dsl.namedArgument("operator", DSL.literal("Or")), - dsl.namedArgument("operator", DSL.literal("and")) + DSL.namedArgument("fuzziness", DSL.literal("AUTO")), + DSL.namedArgument("max_expansions", DSL.literal("50")), + DSL.namedArgument("prefix_length", DSL.literal("0")), + DSL.namedArgument("fuzzy_transpositions", DSL.literal("true")), + DSL.namedArgument("fuzzy_rewrite", DSL.literal("constant_score")), + DSL.namedArgument("minimum_should_match", DSL.literal("3")), + DSL.namedArgument("boost", DSL.literal("1")), + DSL.namedArgument("analyzer", DSL.literal("simple")), + DSL.namedArgument("operator", DSL.literal("Or")), + DSL.namedArgument("operator", DSL.literal("and")) ).stream().map(arg -> List.of(field, query, arg)); } @@ -61,8 +60,8 @@ public void test_valid_arguments(List validArgs) { @Test public void test_valid_when_two_arguments() { List arguments = List.of( - dsl.namedArgument("field", "field_value"), - dsl.namedArgument("query", "query_value")); + DSL.namedArgument("field", "field_value"), + DSL.namedArgument("query", "query_value")); Assertions.assertNotNull(matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -75,7 +74,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(dsl.namedArgument("field", "field_value")); + List arguments = List.of(DSL.namedArgument("field", "field_value")); assertThrows(SyntaxCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } @@ -83,9 +82,9 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SemanticCheckException_when_invalid_argument() { List arguments = List.of( - dsl.namedArgument("field", "field_value"), - dsl.namedArgument("query", "query_value"), - dsl.namedArgument("unsupported", "unsupported_value")); + DSL.namedArgument("field", "field_value"), + DSL.namedArgument("query", "query_value"), + DSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> matchBoolPrefixQuery.build(new MatchExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java index 8e1a2fcdf0..c4a23ae10b 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java @@ -23,12 +23,12 @@ import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhrasePrefixQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class MatchPhrasePrefixQueryTest { +public class MatchPhrasePrefixQueryTest extends OpenSearchTestBase { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); private final MatchPhrasePrefixQuery matchPhrasePrefixQuery = new MatchPhrasePrefixQuery(); private final FunctionName matchPhrasePrefix = FunctionName.of("match_phrase_prefix"); @@ -41,7 +41,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(dsl.namedArgument("field", "test")); + List arguments = List.of(DSL.namedArgument("field", "test")); assertThrows(SyntaxCheckException.class, () -> matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -49,9 +49,9 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SyntaxCheckException_when_invalid_parameter() { List arguments = List.of( - dsl.namedArgument("field", "test"), - dsl.namedArgument("query", "test2"), - dsl.namedArgument("unsupported", "3")); + DSL.namedArgument("field", "test"), + DSL.namedArgument("query", "test2"), + DSL.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -59,9 +59,9 @@ public void test_SyntaxCheckException_when_invalid_parameter() { @Test public void test_analyzer_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("analyzer", "standard") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -69,17 +69,17 @@ public void test_analyzer_parameter() { @Test public void build_succeeds_with_two_arguments() { List arguments = List.of( - dsl.namedArgument("field", "test"), - dsl.namedArgument("query", "test2")); + DSL.namedArgument("field", "test"), + DSL.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @Test public void test_slop_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("slop", "2") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -87,9 +87,9 @@ public void test_slop_parameter() { @Test public void test_zero_terms_query_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("zero_terms_query", "ALL") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -97,9 +97,9 @@ public void test_zero_terms_query_parameter() { @Test public void test_zero_terms_query_parameter_lower_case() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("zero_terms_query", "all") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } @@ -107,9 +107,9 @@ public void test_zero_terms_query_parameter_lower_case() { @Test public void test_boost_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("boost", "0.1") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("boost", "0.1") ); Assertions.assertNotNull(matchPhrasePrefixQuery.build(new MatchPhraseExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java index 09e25fe569..b951d6825a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhraseQueryTest.java @@ -23,12 +23,12 @@ import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhraseQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class MatchPhraseQueryTest { +public class MatchPhraseQueryTest extends OpenSearchTestBase { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); private final MatchPhraseQuery matchPhraseQuery = new MatchPhraseQuery(); private final FunctionName matchPhrase = FunctionName.of("match_phrase"); @@ -41,7 +41,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(dsl.namedArgument("field", "test")); + List arguments = List.of(DSL.namedArgument("field", "test")); assertThrows(SyntaxCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -49,9 +49,9 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SyntaxCheckException_when_invalid_parameter() { List arguments = List.of( - dsl.namedArgument("field", "test"), - dsl.namedArgument("query", "test2"), - dsl.namedArgument("unsupported", "3")); + DSL.namedArgument("field", "test"), + DSL.namedArgument("query", "test2"), + DSL.namedArgument("unsupported", "3")); Assertions.assertThrows(SemanticCheckException.class, () -> matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -59,9 +59,9 @@ public void test_SyntaxCheckException_when_invalid_parameter() { @Test public void test_analyzer_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("analyzer", "standard") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("analyzer", "standard") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -69,17 +69,17 @@ public void test_analyzer_parameter() { @Test public void build_succeeds_with_two_arguments() { List arguments = List.of( - dsl.namedArgument("field", "test"), - dsl.namedArgument("query", "test2")); + DSL.namedArgument("field", "test"), + DSL.namedArgument("query", "test2")); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @Test public void test_slop_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("slop", "2") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("slop", "2") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -87,9 +87,9 @@ public void test_slop_parameter() { @Test public void test_zero_terms_query_parameter() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("zero_terms_query", "ALL") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("zero_terms_query", "ALL") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } @@ -97,9 +97,9 @@ public void test_zero_terms_query_parameter() { @Test public void test_zero_terms_query_parameter_lower_case() { List arguments = List.of( - dsl.namedArgument("field", "t1"), - dsl.namedArgument("query", "t2"), - dsl.namedArgument("zero_terms_query", "all") + DSL.namedArgument("field", "t1"), + DSL.namedArgument("query", "t2"), + DSL.namedArgument("zero_terms_query", "all") ); Assertions.assertNotNull(matchPhraseQuery.build(new MatchPhraseExpression(arguments))); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java index 60dd938f78..87020e5d8f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchQueryTest.java @@ -23,114 +23,112 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.NamedArgumentExpression; -import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class MatchQueryTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); +class MatchQueryTest extends OpenSearchTestBase { private final MatchQuery matchQuery = new MatchQuery(); private final FunctionName match = FunctionName.of("match"); static Stream> generateValidData() { - final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); return Stream.of( List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("analyzer", DSL.literal("standard")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("analyzer", DSL.literal("standard")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("fuzziness", DSL.literal("AUTO")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("fuzziness", DSL.literal("AUTO")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("max_expansions", DSL.literal("50")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("max_expansions", DSL.literal("50")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("prefix_length", DSL.literal("0")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("prefix_length", DSL.literal("0")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("fuzzy_transpositions", DSL.literal("true")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("fuzzy_transpositions", DSL.literal("true")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("fuzzy_rewrite", DSL.literal("constant_score")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("fuzzy_rewrite", DSL.literal("constant_score")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("lenient", DSL.literal("false")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("lenient", DSL.literal("false")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("operator", DSL.literal("OR")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("operator", DSL.literal("OR")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("minimum_should_match", DSL.literal("3")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("minimum_should_match", DSL.literal("3")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("zero_terms_query", DSL.literal("NONE")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("zero_terms_query", DSL.literal("NONE")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("zero_terms_query", DSL.literal("none")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("zero_terms_query", DSL.literal("none")) ), List.of( - dsl.namedArgument("field", DSL.literal("field_value")), - dsl.namedArgument("query", DSL.literal("query_value")), - dsl.namedArgument("boost", DSL.literal("1")) + DSL.namedArgument("field", DSL.literal("field_value")), + DSL.namedArgument("query", DSL.literal("query_value")), + DSL.namedArgument("boost", DSL.literal("1")) ) ); } @ParameterizedTest @MethodSource("generateValidData") - public void test_valid_parameters(List validArgs) { + void test_valid_parameters(List validArgs) { Assertions.assertNotNull(matchQuery.build(new MatchExpression(validArgs))); } @Test - public void test_SyntaxCheckException_when_no_arguments() { + void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); assertThrows(SyntaxCheckException.class, () -> matchQuery.build(new MatchExpression(arguments))); } @Test - public void test_SyntaxCheckException_when_one_argument() { + void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("field", "field_value")); assertThrows(SyntaxCheckException.class, () -> matchQuery.build(new MatchExpression(arguments))); } @Test - public void test_SemanticCheckException_when_invalid_parameter() { + void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("field", "field_value"), namedArgument("query", "query_value"), @@ -140,7 +138,7 @@ public void test_SemanticCheckException_when_invalid_parameter() { } private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); + return DSL.namedArgument(name, DSL.literal(value)); } private class MatchExpression extends FunctionExpression { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java index 748384f4c8..205b11fe21 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MultiMatchTest.java @@ -36,8 +36,6 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class MultiMatchTest { - private static final DSL dsl = new ExpressionConfig() - .dsl(new ExpressionConfig().functionRepository()); private final MultiMatchQuery multiMatchQuery = new MultiMatchQuery(); private final FunctionName multiMatch = FunctionName.of("multi_match"); private static final LiteralExpression fields_value = DSL.literal( @@ -49,83 +47,83 @@ class MultiMatchTest { static Stream> generateValidData() { return Stream.of( List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("analyzer", DSL.literal("simple")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("analyzer", DSL.literal("simple")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("boost", DSL.literal("1.3")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("boost", DSL.literal("1.3")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("cutoff_frequency", DSL.literal("4.2")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("cutoff_frequency", DSL.literal("4.2")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("fuzziness", DSL.literal("AUTO:2,4")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("fuzziness", DSL.literal("AUTO:2,4")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("fuzzy_transpositions", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("fuzzy_transpositions", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("lenient", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("lenient", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("max_expansions", DSL.literal("7")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("max_expansions", DSL.literal("7")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("minimum_should_match", DSL.literal("4")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("minimum_should_match", DSL.literal("4")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("operator", DSL.literal("AND")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("operator", DSL.literal("AND")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("prefix_length", DSL.literal("7")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("prefix_length", DSL.literal("7")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("tie_breaker", DSL.literal("0.3")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("tie_breaker", DSL.literal("0.3")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("type", DSL.literal("cross_fields")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("type", DSL.literal("cross_fields")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("zero_terms_query", DSL.literal("ALL")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("zero_terms_query", DSL.literal("ALL")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("zero_terms_query", DSL.literal("all")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("zero_terms_query", DSL.literal("all")) ) ); } @@ -156,13 +154,13 @@ public void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), - dsl.namedArgument("unsupported", "unsupported_value")); + DSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> multiMatchQuery.build(new MultiMatchExpression(arguments))); } private NamedArgumentExpression namedArgument(String name, LiteralExpression value) { - return dsl.namedArgument(name, value); + return DSL.namedArgument(name, value); } private class MultiMatchExpression extends FunctionExpression { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java index 4692f046db..c13b37bd85 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java @@ -28,15 +28,13 @@ import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedArgumentExpression; -import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.OpenSearchTestBase; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.QueryStringQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -class QueryStringTest { - private static final DSL dsl = new ExpressionConfig() - .dsl(new ExpressionConfig().functionRepository()); +class QueryStringTest extends OpenSearchTestBase { private final QueryStringQuery queryStringQuery = new QueryStringQuery(); private final FunctionName queryStringFunc = FunctionName.of("query_string"); private static final LiteralExpression fields_value = DSL.literal( @@ -46,64 +44,64 @@ class QueryStringTest { private static final LiteralExpression query_value = DSL.literal("query_value"); static Stream> generateValidData() { - Expression field = dsl.namedArgument("fields", fields_value); - Expression query = dsl.namedArgument("query", query_value); - return List.of( - dsl.namedArgument("analyzer", DSL.literal("standard")), - dsl.namedArgument("analyze_wildcard", DSL.literal("true")), - dsl.namedArgument("allow_leading_wildcard", DSL.literal("true")), - dsl.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")), - dsl.namedArgument("boost", DSL.literal("1")), - dsl.namedArgument("default_operator", DSL.literal("AND")), - dsl.namedArgument("default_operator", DSL.literal("and")), - dsl.namedArgument("enable_position_increments", DSL.literal("true")), - dsl.namedArgument("escape", DSL.literal("false")), - dsl.namedArgument("fuzziness", DSL.literal("1")), - dsl.namedArgument("fuzzy_rewrite", DSL.literal("constant_score")), - dsl.namedArgument("fuzzy_max_expansions", DSL.literal("42")), - dsl.namedArgument("fuzzy_prefix_length", DSL.literal("42")), - dsl.namedArgument("fuzzy_transpositions", DSL.literal("true")), - dsl.namedArgument("lenient", DSL.literal("true")), - dsl.namedArgument("max_determinized_states", DSL.literal("10000")), - dsl.namedArgument("minimum_should_match", DSL.literal("4")), - dsl.namedArgument("quote_analyzer", DSL.literal("standard")), - dsl.namedArgument("phrase_slop", DSL.literal("0")), - dsl.namedArgument("quote_field_suffix", DSL.literal(".exact")), - dsl.namedArgument("rewrite", DSL.literal("constant_score")), - dsl.namedArgument("type", DSL.literal("best_fields")), - dsl.namedArgument("tie_breaker", DSL.literal("0.3")), - dsl.namedArgument("time_zone", DSL.literal("Canada/Pacific")), - dsl.namedArgument("ANALYZER", DSL.literal("standard")), - dsl.namedArgument("ANALYZE_wildcard", DSL.literal("true")), - dsl.namedArgument("Allow_Leading_wildcard", DSL.literal("true")), - dsl.namedArgument("Auto_Generate_Synonyms_Phrase_Query", DSL.literal("true")), - dsl.namedArgument("Boost", DSL.literal("1")) - ).stream().map(arg -> List.of(field, query, arg)); + Expression field = DSL.namedArgument("fields", fields_value); + Expression query = DSL.namedArgument("query", query_value); + return Stream.of( + DSL.namedArgument("analyzer", DSL.literal("standard")), + DSL.namedArgument("analyze_wildcard", DSL.literal("true")), + DSL.namedArgument("allow_leading_wildcard", DSL.literal("true")), + DSL.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")), + DSL.namedArgument("boost", DSL.literal("1")), + DSL.namedArgument("default_operator", DSL.literal("AND")), + DSL.namedArgument("default_operator", DSL.literal("and")), + DSL.namedArgument("enable_position_increments", DSL.literal("true")), + DSL.namedArgument("escape", DSL.literal("false")), + DSL.namedArgument("fuzziness", DSL.literal("1")), + DSL.namedArgument("fuzzy_rewrite", DSL.literal("constant_score")), + DSL.namedArgument("fuzzy_max_expansions", DSL.literal("42")), + DSL.namedArgument("fuzzy_prefix_length", DSL.literal("42")), + DSL.namedArgument("fuzzy_transpositions", DSL.literal("true")), + DSL.namedArgument("lenient", DSL.literal("true")), + DSL.namedArgument("max_determinized_states", DSL.literal("10000")), + DSL.namedArgument("minimum_should_match", DSL.literal("4")), + DSL.namedArgument("quote_analyzer", DSL.literal("standard")), + DSL.namedArgument("phrase_slop", DSL.literal("0")), + DSL.namedArgument("quote_field_suffix", DSL.literal(".exact")), + DSL.namedArgument("rewrite", DSL.literal("constant_score")), + DSL.namedArgument("type", DSL.literal("best_fields")), + DSL.namedArgument("tie_breaker", DSL.literal("0.3")), + DSL.namedArgument("time_zone", DSL.literal("Canada/Pacific")), + DSL.namedArgument("ANALYZER", DSL.literal("standard")), + DSL.namedArgument("ANALYZE_wildcard", DSL.literal("true")), + DSL.namedArgument("Allow_Leading_wildcard", DSL.literal("true")), + DSL.namedArgument("Auto_Generate_Synonyms_Phrase_Query", DSL.literal("true")), + DSL.namedArgument("Boost", DSL.literal("1")) + ).map(arg -> List.of(field, query, arg)); } @ParameterizedTest @MethodSource("generateValidData") - public void test_valid_parameters(List validArgs) { + void test_valid_parameters(List validArgs) { Assertions.assertNotNull(queryStringQuery.build( new QueryStringExpression(validArgs))); } @Test - public void test_SyntaxCheckException_when_no_arguments() { + void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } @Test - public void test_SyntaxCheckException_when_one_argument() { + void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } @Test - public void test_SemanticCheckException_when_invalid_parameter() { + void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), @@ -113,11 +111,11 @@ public void test_SemanticCheckException_when_invalid_parameter() { } private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); + return DSL.namedArgument(name, DSL.literal(value)); } private NamedArgumentExpression namedArgument(String name, LiteralExpression value) { - return dsl.namedArgument(name, value); + return DSL.namedArgument(name, value); } private class QueryStringExpression extends FunctionExpression { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java index de8576e9d4..6062a1c62e 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/SimpleQueryStringTest.java @@ -36,8 +36,6 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class SimpleQueryStringTest { - private static final DSL dsl = new ExpressionConfig() - .dsl(new ExpressionConfig().functionRepository()); private final SimpleQueryStringQuery simpleQueryStringQuery = new SimpleQueryStringQuery(); private final FunctionName simpleQueryString = FunctionName.of("simple_query_string"); private static final LiteralExpression fields_value = DSL.literal( @@ -49,107 +47,107 @@ class SimpleQueryStringTest { static Stream> generateValidData() { return Stream.of( List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("analyze_wildcard", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("analyze_wildcard", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("analyzer", DSL.literal("standard")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("analyzer", DSL.literal("standard")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("auto_generate_synonyms_phrase_query", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("flags", DSL.literal("PREFIX")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("flags", DSL.literal("PREFIX")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("flags", DSL.literal("PREFIX|NOT|AND")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("flags", DSL.literal("PREFIX|NOT|AND")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("flags", DSL.literal("NOT|AND")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("flags", DSL.literal("NOT|AND")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("flags", DSL.literal("PREFIX|not|AND")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("flags", DSL.literal("PREFIX|not|AND")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("flags", DSL.literal("not|and")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("flags", DSL.literal("not|and")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("fuzzy_max_expansions", DSL.literal("42")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("fuzzy_max_expansions", DSL.literal("42")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("fuzzy_prefix_length", DSL.literal("42")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("fuzzy_prefix_length", DSL.literal("42")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("fuzzy_transpositions", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("fuzzy_transpositions", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("lenient", DSL.literal("true")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("lenient", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("default_operator", DSL.literal("AND")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("default_operator", DSL.literal("AND")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("default_operator", DSL.literal("and")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("default_operator", DSL.literal("and")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("minimum_should_match", DSL.literal("4")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("minimum_should_match", DSL.literal("4")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("quote_field_suffix", DSL.literal(".exact")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("quote_field_suffix", DSL.literal(".exact")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("boost", DSL.literal("1")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("boost", DSL.literal("1")) ), List.of( - dsl.namedArgument("FIELDS", fields_value), - dsl.namedArgument("QUERY", query_value) + DSL.namedArgument("FIELDS", fields_value), + DSL.namedArgument("QUERY", query_value) ), List.of( - dsl.namedArgument("FIELDS", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("ANALYZE_wildcard", DSL.literal("true")) + DSL.namedArgument("FIELDS", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("ANALYZE_wildcard", DSL.literal("true")) ), List.of( - dsl.namedArgument("fields", fields_value), - dsl.namedArgument("query", query_value), - dsl.namedArgument("analyZER", DSL.literal("standard")) + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("analyZER", DSL.literal("standard")) ) ); } @@ -170,7 +168,7 @@ public void test_SyntaxCheckException_when_no_arguments() { @Test public void test_SyntaxCheckException_when_one_argument() { - List arguments = List.of(namedArgument("fields", fields_value)); + List arguments = List.of(DSL.namedArgument("fields", fields_value)); assertThrows(SyntaxCheckException.class, () -> simpleQueryStringQuery.build(new SimpleQueryStringExpression(arguments))); } @@ -178,21 +176,13 @@ public void test_SyntaxCheckException_when_one_argument() { @Test public void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( - namedArgument("fields", fields_value), - namedArgument("query", query_value), - namedArgument("unsupported", "unsupported_value")); + DSL.namedArgument("fields", fields_value), + DSL.namedArgument("query", query_value), + DSL.namedArgument("unsupported", "unsupported_value")); Assertions.assertThrows(SemanticCheckException.class, () -> simpleQueryStringQuery.build(new SimpleQueryStringExpression(arguments))); } - private NamedArgumentExpression namedArgument(String name, String value) { - return dsl.namedArgument(name, DSL.literal(value)); - } - - private NamedArgumentExpression namedArgument(String name, LiteralExpression value) { - return dsl.namedArgument(name, value); - } - private class SimpleQueryStringExpression extends FunctionExpression { public SimpleQueryStringExpression(List arguments) { super(SimpleQueryStringTest.this.simpleQueryString, arguments); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java index c50f2efb0d..8052911b81 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java @@ -26,10 +26,9 @@ import org.opensearch.sql.expression.config.ExpressionConfig; class MultiFieldQueryTest { - MultiFieldQuery query; - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + MultiFieldQuery query; private final String testQueryName = "test_query"; - private final Map actionMap + private final Map> actionMap = ImmutableMap.of("paramA", (o, v) -> o); @BeforeEach @@ -49,9 +48,9 @@ void createQueryBuilderTest() { var fieldSpec = ImmutableMap.builder().put(sampleField, ExprValueUtils.floatValue(sampleValue)).build(); - query.createQueryBuilder(List.of(dsl.namedArgument("fields", - new LiteralExpression(ExprTupleValue.fromExprValueMap(fieldSpec))), - dsl.namedArgument("query", + query.createQueryBuilder(List.of(DSL.namedArgument("fields", + new LiteralExpression(ExprTupleValue.fromExprValueMap(fieldSpec))), + DSL.namedArgument("query", new LiteralExpression(ExprValueUtils.stringValue(sampleQuery))))); verify(query).createBuilder(argThat( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java index d6f178b1d6..1814f088a9 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/SingleFieldQueryTest.java @@ -23,7 +23,6 @@ class SingleFieldQueryTest { SingleFieldQuery query; - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); private final String testQueryName = "test_query"; private final Map actionMap = ImmutableMap.of("paramA", (o, v) -> o); @@ -41,9 +40,9 @@ void createQueryBuilderTest() { String sampleQuery = "sample query"; String sampleField = "fieldA"; - query.createQueryBuilder(List.of(dsl.namedArgument("field", + query.createQueryBuilder(List.of(DSL.namedArgument("field", new LiteralExpression(ExprValueUtils.stringValue(sampleField))), - dsl.namedArgument("query", + DSL.namedArgument("query", new LiteralExpression(ExprValueUtils.stringValue(sampleQuery))))); verify(query).createBuilder(eq(sampleField), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java index 32aa73babe..6b3bb5f545 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java @@ -15,13 +15,11 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.expression.config.ExpressionConfig; +import org.opensearch.sql.opensearch.OpenSearchTestBase; -class SortQueryBuilderTest { +class SortQueryBuilderTest extends OpenSearchTestBase { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); - - private SortQueryBuilder sortQueryBuilder = new SortQueryBuilder(); + private final SortQueryBuilder sortQueryBuilder = new SortQueryBuilder(); @Test void build_sortbuilder_from_reference() { diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/serialization/DefaultExpressionSerializerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/serialization/DefaultExpressionSerializerTest.java index 1bec475e04..fddbfa1eed 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/serialization/DefaultExpressionSerializerTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/serialization/DefaultExpressionSerializerTest.java @@ -16,6 +16,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; @@ -23,15 +24,14 @@ import org.opensearch.sql.expression.ExpressionNodeVisitor; import org.opensearch.sql.expression.config.ExpressionConfig; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.opensearch.OpenSearchTestBase; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -class DefaultExpressionSerializerTest { +class DefaultExpressionSerializerTest extends OpenSearchTestBase { /** * Initialize function repository manually to avoid dependency on Spring container. */ - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); - private final ExpressionSerializer serializer = new DefaultExpressionSerializer(); @Test diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 14ebbe717a..37cce6f923 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -257,11 +257,6 @@ primaryExpression | dataTypeFunctionCall | fieldExpression | literalValue - | constantFunction - ; - -constantFunction - : constantFunctionName LT_PRTHS functionArgs? RT_PRTHS ; booleanExpression @@ -419,17 +414,44 @@ trigonometricFunctionName ; dateAndTimeFunctionBase - : ADDDATE | CONVERT_TZ | DATE | DATE_ADD | DATE_FORMAT | DATE_SUB - | DATETIME | DAY | DAYNAME | DAYOFMONTH | DAYOFWEEK | DAYOFYEAR | FROM_DAYS | FROM_UNIXTIME - | HOUR | MAKEDATE | MAKETIME | MICROSECOND | MINUTE | MONTH | MONTHNAME | PERIOD_ADD - | PERIOD_DIFF | QUARTER | SECOND | SUBDATE | SYSDATE | TIME | TIME_TO_SEC - | TIMESTAMP | TO_DAYS | UNIX_TIMESTAMP | WEEK | YEAR - ; - -// Functions which value could be cached in scope of a single query -constantFunctionName - : CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | LOCALTIME | LOCALTIMESTAMP | UTC_TIMESTAMP | UTC_DATE | UTC_TIME - | CURDATE | CURTIME | NOW + : datetimeConstantLiteral + | ADDDATE + | CONVERT_TZ + | DATE + | DATE_ADD + | DATE_FORMAT + | DATE_SUB + | DATETIME + | DAY + | DAYNAME + | DAYOFMONTH + | DAYOFWEEK + | DAYOFYEAR + | CURDATE + | CURTIME + | FROM_DAYS + | FROM_UNIXTIME + | HOUR + | MAKEDATE + | MAKETIME + | MICROSECOND + | MINUTE + | MONTH + | MONTHNAME + | NOW + | PERIOD_ADD + | PERIOD_DIFF + | QUARTER + | SECOND + | SUBDATE + | SYSDATE + | TIME + | TIME_TO_SEC + | TIMESTAMP + | TO_DAYS + | UNIX_TIMESTAMP + | WEEK + | YEAR ; /** condition function return boolean value */ @@ -519,6 +541,18 @@ timestampLiteral : TIMESTAMP timestamp=stringLiteral ; +// Actually, these constants are shortcuts to the corresponding functions +datetimeConstantLiteral + : CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | LOCALTIME + | LOCALTIMESTAMP + | UTC_TIMESTAMP + | UTC_DATE + | UTC_TIME + ; + intervalUnit : MICROSECOND | SECOND | MINUTE | HOUR | DAY | WEEK | MONTH | QUARTER | YEAR | SECOND_MICROSECOND | MINUTE_MICROSECOND | MINUTE_SECOND | HOUR_MICROSECOND | HOUR_SECOND | HOUR_MINUTE | DAY_MICROSECOND @@ -563,7 +597,6 @@ keywordsCanBeId | TIMESTAMP | DATE | TIME | FIRST | LAST | timespanUnit | SPAN - | constantFunctionName | dateAndTimeFunctionBase | textFunctionBase | mathematicalFunctionBase diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4430820081..115bcf3cd8 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -14,7 +14,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConstantFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; @@ -23,7 +22,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FunctionArgsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsQualifiedNameContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsWildcardQualifiedNameContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; @@ -61,7 +59,6 @@ import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -232,8 +229,8 @@ public UnresolvedExpression visitTakeAggFunctionCall( @Override public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext ctx) { final String functionName = ctx.conditionFunctionBase().getText(); - return visitFunction(FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), - ctx.functionArgs()); + return buildFunction(FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), + ctx.functionArgs().functionArg()); } /** @@ -241,7 +238,7 @@ public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext */ @Override public UnresolvedExpression visitEvalFunctionCall(EvalFunctionCallContext ctx) { - return visitFunction(ctx.evalFunctionName().getText(), ctx.functionArgs()); + return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg()); } /** @@ -257,23 +254,11 @@ public UnresolvedExpression visitConvertedDataType(ConvertedDataTypeContext ctx) return AstDSL.stringLiteral(ctx.getText()); } - public UnresolvedExpression visitConstantFunction(ConstantFunctionContext ctx) { - return visitConstantFunction(ctx.constantFunctionName().getText(), - ctx.functionArgs()); - } - - private UnresolvedExpression visitConstantFunction(String functionName, - FunctionArgsContext args) { - return new ConstantFunction(functionName, args.functionArg() - .stream() - .map(this::visitFunctionArg) - .collect(Collectors.toList())); - } - - private Function visitFunction(String functionName, FunctionArgsContext args) { + private Function buildFunction(String functionName, + List args) { return new Function( functionName, - args.functionArg() + args .stream() .map(this::visitFunctionArg) .collect(Collectors.toList()) @@ -372,7 +357,7 @@ private QualifiedName visitIdentifiers(List ctx) { } private List singleFieldRelevanceArguments( - OpenSearchPPLParser.SingleFieldRelevanceFunctionContext ctx) { + SingleFieldRelevanceFunctionContext ctx) { // all the arguments are defaulted to string values // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); @@ -387,7 +372,7 @@ private List singleFieldRelevanceArguments( } private List multiFieldRelevanceArguments( - OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { + MultiFieldRelevanceFunctionContext ctx) { // all the arguments are defaulted to string values // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java index 1350305391..711e780f3b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java @@ -8,7 +8,6 @@ import static org.junit.Assert.assertEquals; import static org.opensearch.sql.ast.dsl.AstDSL.compare; -import static org.opensearch.sql.ast.dsl.AstDSL.constantFunction; import static org.opensearch.sql.ast.dsl.AstDSL.eval; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; @@ -18,6 +17,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.relation; import java.util.List; +import org.junit.Assume; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -34,14 +34,11 @@ public class AstNowLikeFunctionTest { * @param name Function name * @param hasFsp Whether function has fsp argument * @param hasShortcut Whether function has shortcut (call without `()`) - * @param isConstantFunction Whether function has constant value */ - public AstNowLikeFunctionTest(String name, Boolean hasFsp, Boolean hasShortcut, - Boolean isConstantFunction) { + public AstNowLikeFunctionTest(String name, Boolean hasFsp, Boolean hasShortcut) { this.name = name; this.hasFsp = hasFsp; this.hasShortcut = hasShortcut; - this.isConstantFunction = isConstantFunction; } /** @@ -51,55 +48,67 @@ public AstNowLikeFunctionTest(String name, Boolean hasFsp, Boolean hasShortcut, @Parameterized.Parameters(name = "{0}") public static Iterable functionNames() { return List.of(new Object[][]{ - {"now", false, false, true}, - {"current_timestamp", false, false, true}, - {"localtimestamp", false, false, true}, - {"localtime", false, false, true}, - {"sysdate", true, false, false}, - {"curtime", false, false, true}, - {"current_time", false, false, true}, - {"curdate", false, false, true}, - {"current_date", false, false, true} + {"now", false, false }, + {"current_timestamp", false, false}, + {"localtimestamp", false, false}, + {"localtime", false, false}, + {"sysdate", true, false}, + {"curtime", false, false}, + {"current_time", false, false}, + {"curdate", false, false}, + {"current_date", false, false} }); } private final String name; - private final Boolean hasFsp; - private final Boolean hasShortcut; - private final Boolean isConstantFunction; + private final boolean hasFsp; + private final boolean hasShortcut; @Test - public void test_now_like_functions() { - for (var call : hasShortcut ? List.of(name, name + "()") : List.of(name + "()")) { - assertEqual("source=t | eval r=" + call, - eval( - relation("t"), - let( - field("r"), - (isConstantFunction ? constantFunction(name) : function(name)) - ) - )); - - assertEqual("search source=t | where a=" + call, - filter( - relation("t"), - compare("=", field("a"), - (isConstantFunction ? constantFunction(name) : function(name))) - ) - ); - } - // Unfortunately, only real functions (not ConstantFunction) might have `fsp` now. - if (hasFsp) { - assertEqual("search source=t | where a=" + name + "(0)", - filter( - relation("t"), - compare("=", field("a"), function(name, intLiteral(0))) - ) - ); - } + public void test_function_call_eval() { + assertEqual( + eval(relation("t"), let(field("r"), function(name))), + "source=t | eval r=" + name + "()" + ); } - protected void assertEqual(String query, Node expectedPlan) { + @Test + public void test_shortcut_eval() { + Assume.assumeTrue(hasShortcut); + assertEqual( + eval(relation("t"), let(field("r"), function(name))), + "source=t | eval r=" + name + ); + } + + @Test + public void test_function_call_where() { + assertEqual( + filter(relation("t"), compare("=", field("a"), function(name))), + "search source=t | where a=" + name + "()" + ); + } + + @Test + public void test_shortcut_where() { + Assume.assumeTrue(hasShortcut); + assertEqual( + filter(relation("t"), compare("=", field("a"), function(name))), + "search source=t | where a=" + name + ); + } + + @Test + public void test_function_call_fsp() { + Assume.assumeTrue(hasFsp); + assertEqual(filter( + relation("t"), + compare("=", field("a"), function(name, intLiteral(0))) + ), "search source=t | where a=" + name + "(0)" + ); + } + + protected void assertEqual(Node expectedPlan, String query) { Node actualPlan = plan(query); assertEquals(expectedPlan, actualPlan); } diff --git a/prometheus/build.gradle b/prometheus/build.gradle index 45a3a4a8ed..774fc5ca7c 100644 --- a/prometheus/build.gradle +++ b/prometheus/build.gradle @@ -25,6 +25,7 @@ dependencies { implementation group: 'org.json', name: 'json', version: '20180813' testImplementation('org.junit.jupiter:junit-jupiter:5.6.2') + testImplementation "org.springframework:spring-test:${spring_version}" testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1' testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.12.4' testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4' diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java index 63d41fb1d8..a1d72f98c3 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java @@ -46,7 +46,7 @@ public Pair resolve(FunctionSignature unreso new FunctionSignature(functionName, List.of(STRING, LONG, LONG, LONG)); final List argumentNames = List.of(QUERY, STARTTIME, ENDTIME, STEP); - FunctionBuilder functionBuilder = arguments -> { + FunctionBuilder functionBuilder = (functionProperties, arguments) -> { Boolean argumentsPassedByName = arguments.stream() .noneMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); Boolean argumentsPassedByPosition = arguments.stream() diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java index caca48f834..06f003c9b6 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.TableFunctionImplementation; import org.opensearch.sql.prometheus.client.PrometheusClient; @@ -40,6 +41,9 @@ class QueryRangeTableFunctionResolverTest { @Mock private PrometheusClient client; + @Mock + private FunctionProperties functionProperties; + @Test void testResolve() { QueryRangeTableFunctionResolver queryRangeTableFunctionResolver @@ -59,7 +63,7 @@ void testResolve() { assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); FunctionBuilder functionBuilder = resolution.getValue(); TableFunctionImplementation functionImplementation - = (TableFunctionImplementation) functionBuilder.apply(expressions); + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); assertTrue(functionImplementation instanceof QueryRangeFunctionImplementation); PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable) functionImplementation.applyArguments(); @@ -93,7 +97,7 @@ void testArgumentsPassedByPosition() { assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); FunctionBuilder functionBuilder = resolution.getValue(); TableFunctionImplementation functionImplementation - = (TableFunctionImplementation) functionBuilder.apply(expressions); + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); assertTrue(functionImplementation instanceof QueryRangeFunctionImplementation); PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable) functionImplementation.applyArguments(); @@ -128,7 +132,7 @@ void testArgumentsPassedByNameWithDifferentOrder() { assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); FunctionBuilder functionBuilder = resolution.getValue(); TableFunctionImplementation functionImplementation - = (TableFunctionImplementation) functionBuilder.apply(expressions); + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); assertTrue(functionImplementation instanceof QueryRangeFunctionImplementation); PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable) functionImplementation.applyArguments(); @@ -160,7 +164,7 @@ void testMixedArgumentTypes() { assertEquals(functionName, queryRangeTableFunctionResolver.getFunctionName()); assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> resolution.getValue().apply(expressions)); + () -> resolution.getValue().apply(functionProperties, expressions)); assertEquals("Arguments should be either passed by name or position", exception.getMessage()); } @@ -182,7 +186,7 @@ void testWrongArgumentsSizeWhenPassedByName() { assertEquals(functionName, queryRangeTableFunctionResolver.getFunctionName()); assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> resolution.getValue().apply(expressions)); + () -> resolution.getValue().apply(functionProperties, expressions)); assertEquals("Missing arguments:[endtime,starttime]", exception.getMessage()); } @@ -204,7 +208,7 @@ void testWrongArgumentsSizeWhenPassedByPosition() { assertEquals(functionName, queryRangeTableFunctionResolver.getFunctionName()); assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> resolution.getValue().apply(expressions)); + () -> resolution.getValue().apply(functionProperties, expressions)); assertEquals("Missing arguments:[endtime,step]", exception.getMessage()); } diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/planner/logical/PrometheusLogicOptimizerTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/planner/logical/PrometheusLogicOptimizerTest.java index 7d6d3bed28..c181e0263c 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/planner/logical/PrometheusLogicOptimizerTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/planner/logical/PrometheusLogicOptimizerTest.java @@ -29,11 +29,19 @@ import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; import org.opensearch.sql.storage.Table; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; @ExtendWith(MockitoExtension.class) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ExpressionConfig.class}) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) public class PrometheusLogicOptimizerTest { - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); + @Autowired + DSL dsl; @Mock private Table table; diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java index ff5ae5dcf5..8cf9e7c7dd 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/storage/PrometheusMetricTableTest.java @@ -56,15 +56,23 @@ import org.opensearch.sql.prometheus.client.PrometheusClient; import org.opensearch.sql.prometheus.constants.TestConstants; import org.opensearch.sql.prometheus.request.PrometheusQueryRequest; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.test.annotation.DirtiesContext; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit.jupiter.SpringExtension; @ExtendWith(MockitoExtension.class) +@ExtendWith(SpringExtension.class) +@ContextConfiguration(classes = {ExpressionConfig.class}) +@DirtiesContext(classMode = DirtiesContext.ClassMode.AFTER_EACH_TEST_METHOD) class PrometheusMetricTableTest { + @Autowired + DSL dsl; + @Mock private PrometheusClient client; - private final DSL dsl = new ExpressionConfig().dsl(new ExpressionConfig().functionRepository()); - @Test @SneakyThrows void testGetFieldTypesFromMetric() { diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index c803f2b5c3..bc7147a2c0 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -239,6 +239,18 @@ timestampLiteral : TIMESTAMP timestamp=stringLiteral ; +// Actually, these constants are shortcuts to the corresponding functions +datetimeConstantLiteral + : CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | LOCALTIME + | LOCALTIMESTAMP + | UTC_TIMESTAMP + | UTC_DATE + | UTC_TIME + ; + intervalLiteral : INTERVAL expression intervalUnit ; @@ -301,12 +313,8 @@ functionCall | aggregateFunction (orderByClause)? filterClause #filteredAggregationFunctionCall | relevanceFunction #relevanceFunctionCall | highlightFunction #highlightFunctionCall - | constantFunction #constantFunctionCall ; -constantFunction - : constantFunctionName LR_BRACKET functionArgs RR_BRACKET - ; highlightFunction : HIGHLIGHT LR_BRACKET relevanceField (COMMA highlightArg)* RR_BRACKET @@ -393,17 +401,44 @@ trigonometricFunctionName ; dateTimeFunctionName - : ADDDATE | CONVERT_TZ | DATE | DATE_ADD | DATE_FORMAT | DATE_SUB - | DATETIME | DAY | DAYNAME | DAYOFMONTH | DAYOFWEEK | DAYOFYEAR | FROM_DAYS | FROM_UNIXTIME - | HOUR | MAKEDATE | MAKETIME | MICROSECOND | MINUTE | MONTH | MONTHNAME | PERIOD_ADD - | PERIOD_DIFF | QUARTER | SECOND | SUBDATE | SYSDATE | TIME | TIME_TO_SEC - | TIMESTAMP | TO_DAYS | UNIX_TIMESTAMP | WEEK | YEAR - ; - -// Functions which value could be cached in scope of a single query -constantFunctionName - : CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | LOCALTIME | LOCALTIMESTAMP | UTC_TIMESTAMP | UTC_DATE | UTC_TIME - | CURDATE | CURTIME | NOW + : datetimeConstantLiteral + | ADDDATE + | CONVERT_TZ + | CURDATE + | CURTIME + | DATE + | DATE_ADD + | DATE_FORMAT + | DATE_SUB + | DATETIME + | DAY + | DAYNAME + | DAYOFMONTH + | DAYOFWEEK + | DAYOFYEAR + | FROM_DAYS + | FROM_UNIXTIME + | HOUR + | MAKEDATE + | MAKETIME + | MICROSECOND + | MINUTE + | MONTH + | MONTHNAME + | NOW + | PERIOD_ADD + | PERIOD_DIFF + | QUARTER + | SECOND + | SUBDATE + | SYSDATE + | TIME + | TIME_TO_SEC + | TIMESTAMP + | TO_DAYS + | UNIX_TIMESTAMP + | WEEK + | YEAR ; textFunctionName diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 131b6d9116..18efef039f 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -18,7 +18,6 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CaseFuncAlternativeContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CaseFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnFilterContext; -import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ConstantFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ConvertedDataTypeContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CountStarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext; @@ -63,7 +62,6 @@ import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.HighlightFunction; @@ -83,7 +81,6 @@ import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext; -import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FunctionArgsContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IdentContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IntervalLiteralContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NestedExpressionAtomContext; @@ -131,7 +128,7 @@ public UnresolvedExpression visitNestedExpressionAtom(NestedExpressionAtomContex @Override public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ctx) { - return visitFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs()); + return buildFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs().functionArg()); } @Override @@ -206,7 +203,7 @@ public UnresolvedExpression visitWindowFunctionClause(WindowFunctionClauseContex @Override public UnresolvedExpression visitScalarWindowFunction(ScalarWindowFunctionContext ctx) { - return visitFunction(ctx.functionName.getText(), ctx.functionArgs()); + return buildFunction(ctx.functionName.getText(), ctx.functionArgs().functionArg()); } @Override @@ -404,30 +401,17 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( multiFieldRelevanceArguments(ctx)); } - private Function visitFunction(String functionName, FunctionArgsContext args) { + private Function buildFunction(String functionName, + List arg) { return new Function( functionName, - args.functionArg() + arg .stream() .map(this::visitFunctionArg) .collect(Collectors.toList()) ); } - @Override - public UnresolvedExpression visitConstantFunction(ConstantFunctionContext ctx) { - return visitConstantFunction(ctx.constantFunctionName().getText(), - ctx.functionArgs()); - } - - private UnresolvedExpression visitConstantFunction(String functionName, - FunctionArgsContext args) { - return new ConstantFunction(functionName, args.functionArg() - .stream() - .map(this::visitFunctionArg) - .collect(Collectors.toList())); - } - private QualifiedName visitIdentifiers(List identifiers) { return new QualifiedName( identifiers.stream() diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index a955399c4d..2aed4f2834 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -14,7 +14,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.alias; import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; -import static org.opensearch.sql.ast.dsl.AstDSL.constantFunction; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; @@ -50,12 +49,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; -class AstBuilderTest { - - /** - * SQL syntax parser that helps prepare parse tree as AstBuilder input. - */ - private final SQLSyntaxParser parser = new SQLSyntaxParser(); +class AstBuilderTest extends AstBuilderTestBase { @Test public void can_build_select_literals() { @@ -679,60 +673,6 @@ public void can_build_limit_clause_with_offset() { buildAST("SELECT name FROM test LIMIT 5, 10")); } - private static Stream nowLikeFunctionsData() { - return Stream.of( - Arguments.of("now", false, false, true), - Arguments.of("current_timestamp", false, false, true), - Arguments.of("localtimestamp", false, false, true), - Arguments.of("localtime", false, false, true), - Arguments.of("sysdate", true, false, false), - Arguments.of("curtime", false, false, true), - Arguments.of("current_time", false, false, true), - Arguments.of("curdate", false, false, true), - Arguments.of("current_date", false, false, true) - ); - } - - @ParameterizedTest(name = "{0}") - @MethodSource("nowLikeFunctionsData") - public void test_now_like_functions(String name, Boolean hasFsp, Boolean hasShortcut, - Boolean isConstantFunction) { - for (var call : hasShortcut ? List.of(name, name + "()") : List.of(name + "()")) { - assertEquals( - project( - values(emptyList()), - alias(call, (isConstantFunction ? constantFunction(name) : function(name))) - ), - buildAST("SELECT " + call) - ); - - assertEquals( - project( - filter( - relation("test"), - function( - "=", - qualifiedName("data"), - (isConstantFunction ? constantFunction(name) : function(name))) - ), - AllFields.of() - ), - buildAST("SELECT * FROM test WHERE data = " + call) - ); - } - - // Unfortunately, only real functions (not ConstantFunction) might have `fsp` now. - if (hasFsp) { - assertEquals( - project( - values(emptyList()), - alias(name + "(0)", function(name, intLiteral(0))) - ), - buildAST("SELECT " + name + "(0)") - ); - } - } - @Test public void can_build_qualified_name_highlight() { Map args = new HashMap<>(); @@ -769,8 +709,4 @@ public void can_build_string_literal_highlight() { ); } - private UnresolvedPlan buildAST(String query) { - ParseTree parseTree = parser.parse(query); - return parseTree.accept(new AstBuilder(query)); - } } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java new file mode 100644 index 0000000000..2161eb5b1a --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.sql.antlr.SQLSyntaxParser; + +public class AstBuilderTestBase { + /** + * SQL syntax parser that helps prepare parse tree as AstBuilder input. + */ + private final SQLSyntaxParser parser = new SQLSyntaxParser(); + + protected UnresolvedPlan buildAST(String query) { + ParseTree parseTree = parser.parse(query); + return parseTree.accept(new AstBuilder(query)); + } +} diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java new file mode 100644 index 0000000000..19b48ca0bd --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import static java.util.Collections.emptyList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.filter; +import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.project; +import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; +import static org.opensearch.sql.ast.dsl.AstDSL.relation; +import static org.opensearch.sql.ast.dsl.AstDSL.values; + +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.ast.expression.AllFields; + +class AstNowLikeFunctionTest extends AstBuilderTestBase { + + private static Stream allFunctions() { + return Stream.of("curdate", + "current_date", + "current_time", + "current_timestamp", + "curtime", + "localtimestamp", + "localtime", + "now", + "sysdate") + .map(Arguments::of); + } + + private static Stream supportFsp() { + return Stream.of("sysdate") + .map(Arguments::of); + } + + private static Stream supportShortcut() { + return Stream.of("current_date", + "current_time", + "current_timestamp", + "localtimestamp", + "localtime") + .map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("allFunctions") + void project_call(String name) { + String call = name + "()"; + assertEquals( + project( + values(emptyList()), + alias(call, function(name)) + ), + buildAST("SELECT " + call) + ); + } + + @ParameterizedTest + @MethodSource("allFunctions") + void filter_call(String name) { + String call = name + "()"; + assertEquals( + project( + filter( + relation("test"), + function( + "=", + qualifiedName("data"), + function(name)) + ), + AllFields.of() + ), + buildAST("SELECT * FROM test WHERE data = " + call) + ); + } + + + @ParameterizedTest + @MethodSource("supportFsp") + void fsp(String name) { + assertEquals( + project( + values(emptyList()), + alias(name + "(0)", function(name, intLiteral(0))) + ), + buildAST("SELECT " + name + "(0)") + ); + } +}