diff --git a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java index f3fd623371..4704d0566b 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java +++ b/core/src/main/java/org/opensearch/sql/analysis/AnalysisContext.java @@ -7,13 +7,11 @@ package org.opensearch.sql.analysis; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; -import java.util.Map; import java.util.Objects; import lombok.Getter; -import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.function.FunctionProperties; /** * The context used for Analyzer. @@ -26,13 +24,8 @@ public class AnalysisContext { @Getter private final List namedParseExpressions; - /** - * Storage for values of functions which return a constant value. - * We are storing the values there to use it in sequential calls to those functions. - * For example, `now` function should the same value during processing a query. - */ @Getter - private final Map constantFunctionValues; + private final FunctionProperties functionProperties; public AnalysisContext() { this(new TypeEnvironment(null)); @@ -45,7 +38,7 @@ public AnalysisContext() { public AnalysisContext(TypeEnvironment environment) { this.environment = environment; this.namedParseExpressions = new ArrayList<>(); - this.constantFunctionValues = new HashMap<>(); + this.functionProperties = new FunctionProperties(); } /** diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index d463ed424d..228b54ba0c 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -196,7 +196,7 @@ public LogicalPlan visitTableFunction(TableFunction node, AnalysisContext contex .map(unresolvedExpression -> this.expressionAnalyzer.analyze(unresolvedExpression, context)) .collect(Collectors.toList()); TableFunctionImplementation tableFunctionImplementation - = (TableFunctionImplementation) repository.compile( + = (TableFunctionImplementation) repository.compile(context.getFunctionProperties(), dataSourceSchemaIdentifierNameResolver.getDataSourceName(), functionName, arguments); context.push(); TypeEnvironment curEnv = context.peek(); diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index a3ba9b1b6b..719c3adbce 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -25,7 +25,6 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -78,7 +77,8 @@ public class ExpressionAnalyzer extends AbstractNodeVisitor analyze(unresolvedExpression, context)) .collect(Collectors.toList()); - return (Expression) repository.compile(functionName, arguments); + return (Expression) repository.compile(context.getFunctionProperties(), + functionName, arguments); } @SuppressWarnings("unchecked") @@ -237,7 +225,8 @@ public Expression visitCompare(Compare node, AnalysisContext context) { Expression left = analyze(node.getLeft(), context); Expression right = analyze(node.getRight(), context); return (Expression) - repository.compile(functionName, Arrays.asList(left, right)); + repository.compile(context.getFunctionProperties(), + functionName, Arrays.asList(left, right)); } @Override diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java index 433c5fb809..f75bcd5a1d 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizer.java @@ -70,7 +70,8 @@ public Expression visitFunction(FunctionExpression node, AnalysisContext context final List args = node.getArguments().stream().map(expr -> expr.accept(this, context)) .collect(Collectors.toList()); - return (Expression) repository.compile(node.getFunctionName(), args); + return (Expression) repository.compile(context.getFunctionProperties(), + node.getFunctionName(), args); } } diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 53ff93eec1..fe993c899e 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -15,7 +15,6 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -126,10 +125,6 @@ public T visitRelevanceFieldList(RelevanceFieldList node, C context) { return visitChildren(node, context); } - public T visitConstantFunction(ConstantFunction node, C context) { - return visitChildren(node, context); - } - public T visitUnresolvedAttribute(UnresolvedAttribute node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 24ada1fd92..2959cae4a1 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -19,7 +19,6 @@ import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.EqualTo; import org.opensearch.sql.ast.expression.Field; @@ -246,10 +245,6 @@ public static Function function(String funcName, UnresolvedExpression... funcArg return new Function(funcName, Arrays.asList(funcArgs)); } - public static Function constantFunction(String funcName, UnresolvedExpression... funcArgs) { - return new ConstantFunction(funcName, Arrays.asList(funcArgs)); - } - /** * CASE * WHEN search_condition THEN result_expr diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java deleted file mode 100644 index f14e65eeb2..0000000000 --- a/core/src/main/java/org/opensearch/sql/ast/expression/ConstantFunction.java +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.ast.expression; - -import java.util.List; -import lombok.EqualsAndHashCode; -import org.opensearch.sql.ast.AbstractNodeVisitor; - -/** - * Expression node that holds a function which should be replaced by its constant[1] value. - * [1] Constant at execution time. - */ -@EqualsAndHashCode(callSuper = false) -public class ConstantFunction extends Function { - - public ConstantFunction(String funcName, List funcArgs) { - super(funcName, funcArgs); - } - - @Override - public R accept(AbstractNodeVisitor nodeVisitor, C context) { - return nodeVisitor.visitConstantFunction(this, context); - } -} diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 6486c4b676..19cd8fd326 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -19,6 +19,7 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.FunctionImplementation; +import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.expression.parse.GrokExpression; import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.expression.parse.PatternsExpression; @@ -146,375 +147,375 @@ public static SpanExpression span(Expression field, Expression value, String uni } public static FunctionExpression abs(Expression... expressions) { - return compile(BuiltinFunctionName.ABS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ABS, expressions); } public static FunctionExpression ceil(Expression... expressions) { - return compile(BuiltinFunctionName.CEIL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CEIL, expressions); } public static FunctionExpression ceiling(Expression... expressions) { - return compile(BuiltinFunctionName.CEILING, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CEILING, expressions); } public static FunctionExpression conv(Expression... expressions) { - return compile(BuiltinFunctionName.CONV, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CONV, expressions); } public static FunctionExpression crc32(Expression... expressions) { - return compile(BuiltinFunctionName.CRC32, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CRC32, expressions); } public static FunctionExpression euler(Expression... expressions) { - return compile(BuiltinFunctionName.E, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.E, expressions); } public static FunctionExpression exp(Expression... expressions) { - return compile(BuiltinFunctionName.EXP, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.EXP, expressions); } public static FunctionExpression floor(Expression... expressions) { - return compile(BuiltinFunctionName.FLOOR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.FLOOR, expressions); } public static FunctionExpression ln(Expression... expressions) { - return compile(BuiltinFunctionName.LN, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LN, expressions); } public static FunctionExpression log(Expression... expressions) { - return compile(BuiltinFunctionName.LOG, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LOG, expressions); } public static FunctionExpression log10(Expression... expressions) { - return compile(BuiltinFunctionName.LOG10, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LOG10, expressions); } public static FunctionExpression log2(Expression... expressions) { - return compile(BuiltinFunctionName.LOG2, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LOG2, expressions); } public static FunctionExpression mod(Expression... expressions) { - return compile(BuiltinFunctionName.MOD, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MOD, expressions); } public static FunctionExpression pi(Expression... expressions) { - return compile(BuiltinFunctionName.PI, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.PI, expressions); } public static FunctionExpression pow(Expression... expressions) { - return compile(BuiltinFunctionName.POW, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.POW, expressions); } public static FunctionExpression power(Expression... expressions) { - return compile(BuiltinFunctionName.POWER, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.POWER, expressions); } public static FunctionExpression rand(Expression... expressions) { - return compile(BuiltinFunctionName.RAND, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.RAND, expressions); } public static FunctionExpression round(Expression... expressions) { - return compile(BuiltinFunctionName.ROUND, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ROUND, expressions); } public static FunctionExpression sign(Expression... expressions) { - return compile(BuiltinFunctionName.SIGN, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SIGN, expressions); } public static FunctionExpression sqrt(Expression... expressions) { - return compile(BuiltinFunctionName.SQRT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SQRT, expressions); } public static FunctionExpression cbrt(Expression... expressions) { - return compile(BuiltinFunctionName.CBRT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CBRT, expressions); } public static FunctionExpression truncate(Expression... expressions) { - return compile(BuiltinFunctionName.TRUNCATE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TRUNCATE, expressions); } public static FunctionExpression acos(Expression... expressions) { - return compile(BuiltinFunctionName.ACOS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ACOS, expressions); } public static FunctionExpression asin(Expression... expressions) { - return compile(BuiltinFunctionName.ASIN, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ASIN, expressions); } public static FunctionExpression atan(Expression... expressions) { - return compile(BuiltinFunctionName.ATAN, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ATAN, expressions); } public static FunctionExpression atan2(Expression... expressions) { - return compile(BuiltinFunctionName.ATAN2, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ATAN2, expressions); } public static FunctionExpression cos(Expression... expressions) { - return compile(BuiltinFunctionName.COS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.COS, expressions); } public static FunctionExpression cot(Expression... expressions) { - return compile(BuiltinFunctionName.COT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.COT, expressions); } public static FunctionExpression degrees(Expression... expressions) { - return compile(BuiltinFunctionName.DEGREES, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DEGREES, expressions); } public static FunctionExpression radians(Expression... expressions) { - return compile(BuiltinFunctionName.RADIANS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.RADIANS, expressions); } public static FunctionExpression sin(Expression... expressions) { - return compile(BuiltinFunctionName.SIN, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SIN, expressions); } public static FunctionExpression tan(Expression... expressions) { - return compile(BuiltinFunctionName.TAN, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TAN, expressions); } public static FunctionExpression add(Expression... expressions) { - return compile(BuiltinFunctionName.ADD, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ADD, expressions); } public static FunctionExpression subtract(Expression... expressions) { - return compile(BuiltinFunctionName.SUBTRACT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SUBTRACT, expressions); } public static FunctionExpression multiply(Expression... expressions) { - return compile(BuiltinFunctionName.MULTIPLY, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MULTIPLY, expressions); } public static FunctionExpression adddate(Expression... expressions) { - return compile(BuiltinFunctionName.ADDDATE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ADDDATE, expressions); } public static FunctionExpression convert_tz(Expression... expressions) { - return compile(BuiltinFunctionName.CONVERT_TZ, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CONVERT_TZ, expressions); } public static FunctionExpression date(Expression... expressions) { - return compile(BuiltinFunctionName.DATE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DATE, expressions); } public static FunctionExpression datetime(Expression... expressions) { - return compile(BuiltinFunctionName.DATETIME, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DATETIME, expressions); } public static FunctionExpression date_add(Expression... expressions) { - return compile(BuiltinFunctionName.DATE_ADD, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DATE_ADD, expressions); } public static FunctionExpression date_sub(Expression... expressions) { - return compile(BuiltinFunctionName.DATE_SUB, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DATE_SUB, expressions); } public static FunctionExpression day(Expression... expressions) { - return compile(BuiltinFunctionName.DAY, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DAY, expressions); } public static FunctionExpression dayname(Expression... expressions) { - return compile(BuiltinFunctionName.DAYNAME, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DAYNAME, expressions); } public static FunctionExpression dayofmonth(Expression... expressions) { - return compile(BuiltinFunctionName.DAYOFMONTH, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DAYOFMONTH, expressions); } public static FunctionExpression dayofweek(Expression... expressions) { - return compile(BuiltinFunctionName.DAYOFWEEK, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DAYOFWEEK, expressions); } public static FunctionExpression dayofyear(Expression... expressions) { - return compile(BuiltinFunctionName.DAYOFYEAR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DAYOFYEAR, expressions); } public static FunctionExpression from_days(Expression... expressions) { - return compile(BuiltinFunctionName.FROM_DAYS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.FROM_DAYS, expressions); } public static FunctionExpression hour(Expression... expressions) { - return compile(BuiltinFunctionName.HOUR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.HOUR, expressions); } public static FunctionExpression microsecond(Expression... expressions) { - return compile(BuiltinFunctionName.MICROSECOND, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MICROSECOND, expressions); } public static FunctionExpression minute(Expression... expressions) { - return compile(BuiltinFunctionName.MINUTE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MINUTE, expressions); } public static FunctionExpression month(Expression... expressions) { - return compile(BuiltinFunctionName.MONTH, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MONTH, expressions); } public static FunctionExpression monthname(Expression... expressions) { - return compile(BuiltinFunctionName.MONTHNAME, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MONTHNAME, expressions); } public static FunctionExpression quarter(Expression... expressions) { - return compile(BuiltinFunctionName.QUARTER, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.QUARTER, expressions); } public static FunctionExpression second(Expression... expressions) { - return compile(BuiltinFunctionName.SECOND, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SECOND, expressions); } public static FunctionExpression subdate(Expression... expressions) { - return compile(BuiltinFunctionName.SUBDATE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SUBDATE, expressions); } public static FunctionExpression time(Expression... expressions) { - return compile(BuiltinFunctionName.TIME, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TIME, expressions); } public static FunctionExpression time_to_sec(Expression... expressions) { - return compile(BuiltinFunctionName.TIME_TO_SEC, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TIME_TO_SEC, expressions); } public static FunctionExpression timestamp(Expression... expressions) { - return compile(BuiltinFunctionName.TIMESTAMP, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TIMESTAMP, expressions); } public static FunctionExpression date_format(Expression... expressions) { - return compile(BuiltinFunctionName.DATE_FORMAT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DATE_FORMAT, expressions); } public static FunctionExpression to_days(Expression... expressions) { - return compile(BuiltinFunctionName.TO_DAYS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TO_DAYS, expressions); } public static FunctionExpression week(Expression... expressions) { - return compile(BuiltinFunctionName.WEEK, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.WEEK, expressions); } public static FunctionExpression year(Expression... expressions) { - return compile(BuiltinFunctionName.YEAR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.YEAR, expressions); } public static FunctionExpression divide(Expression... expressions) { - return compile(BuiltinFunctionName.DIVIDE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.DIVIDE, expressions); } public static FunctionExpression module(Expression... expressions) { - return compile(BuiltinFunctionName.MODULES, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.MODULES, expressions); } public static FunctionExpression substr(Expression... expressions) { - return compile(BuiltinFunctionName.SUBSTR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SUBSTR, expressions); } public static FunctionExpression substring(Expression... expressions) { - return compile(BuiltinFunctionName.SUBSTR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.SUBSTR, expressions); } public static FunctionExpression ltrim(Expression... expressions) { - return compile(BuiltinFunctionName.LTRIM, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LTRIM, expressions); } public static FunctionExpression rtrim(Expression... expressions) { - return compile(BuiltinFunctionName.RTRIM, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.RTRIM, expressions); } public static FunctionExpression trim(Expression... expressions) { - return compile(BuiltinFunctionName.TRIM, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.TRIM, expressions); } public static FunctionExpression upper(Expression... expressions) { - return compile(BuiltinFunctionName.UPPER, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.UPPER, expressions); } public static FunctionExpression lower(Expression... expressions) { - return compile(BuiltinFunctionName.LOWER, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LOWER, expressions); } public static FunctionExpression regexp(Expression... expressions) { - return compile(BuiltinFunctionName.REGEXP, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.REGEXP, expressions); } public static FunctionExpression concat(Expression... expressions) { - return compile(BuiltinFunctionName.CONCAT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CONCAT, expressions); } public static FunctionExpression concat_ws(Expression... expressions) { - return compile(BuiltinFunctionName.CONCAT_WS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.CONCAT_WS, expressions); } public static FunctionExpression length(Expression... expressions) { - return compile(BuiltinFunctionName.LENGTH, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LENGTH, expressions); } public static FunctionExpression strcmp(Expression... expressions) { - return compile(BuiltinFunctionName.STRCMP, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.STRCMP, expressions); } public static FunctionExpression right(Expression... expressions) { - return compile(BuiltinFunctionName.RIGHT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.RIGHT, expressions); } public static FunctionExpression left(Expression... expressions) { - return compile(BuiltinFunctionName.LEFT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LEFT, expressions); } public static FunctionExpression ascii(Expression... expressions) { - return compile(BuiltinFunctionName.ASCII, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ASCII, expressions); } public static FunctionExpression locate(Expression... expressions) { - return compile(BuiltinFunctionName.LOCATE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LOCATE, expressions); } public static FunctionExpression replace(Expression... expressions) { - return compile(BuiltinFunctionName.REPLACE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.REPLACE, expressions); } public static FunctionExpression and(Expression... expressions) { - return compile(BuiltinFunctionName.AND, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.AND, expressions); } public static FunctionExpression or(Expression... expressions) { - return compile(BuiltinFunctionName.OR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.OR, expressions); } public static FunctionExpression xor(Expression... expressions) { - return compile(BuiltinFunctionName.XOR, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.XOR, expressions); } public static FunctionExpression not(Expression... expressions) { - return compile(BuiltinFunctionName.NOT, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.NOT, expressions); } public static FunctionExpression equal(Expression... expressions) { - return compile(BuiltinFunctionName.EQUAL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.EQUAL, expressions); } public static FunctionExpression notequal(Expression... expressions) { - return compile(BuiltinFunctionName.NOTEQUAL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.NOTEQUAL, expressions); } public static FunctionExpression less(Expression... expressions) { - return compile(BuiltinFunctionName.LESS, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LESS, expressions); } public static FunctionExpression lte(Expression... expressions) { - return compile(BuiltinFunctionName.LTE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LTE, expressions); } public static FunctionExpression greater(Expression... expressions) { - return compile(BuiltinFunctionName.GREATER, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.GREATER, expressions); } public static FunctionExpression gte(Expression... expressions) { - return compile(BuiltinFunctionName.GTE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.GTE, expressions); } public static FunctionExpression like(Expression... expressions) { - return compile(BuiltinFunctionName.LIKE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.LIKE, expressions); } public static FunctionExpression notLike(Expression... expressions) { - return compile(BuiltinFunctionName.NOT_LIKE, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.NOT_LIKE, expressions); } public static Aggregator avg(Expression... expressions) { @@ -554,15 +555,15 @@ public static Aggregator take(Expression... expressions) { } public static RankingWindowFunction rowNumber() { - return compile(BuiltinFunctionName.ROW_NUMBER); + return compile(FunctionProperties.None, BuiltinFunctionName.ROW_NUMBER); } public static RankingWindowFunction rank() { - return compile(BuiltinFunctionName.RANK); + return compile(FunctionProperties.None, BuiltinFunctionName.RANK); } public static RankingWindowFunction denseRank() { - return compile(BuiltinFunctionName.DENSE_RANK); + return compile(FunctionProperties.None, BuiltinFunctionName.DENSE_RANK); } public static Aggregator min(Expression... expressions) { @@ -574,31 +575,31 @@ public static Aggregator max(Expression... expressions) { } private static Aggregator aggregate(BuiltinFunctionName functionName, Expression... expressions) { - return compile(functionName, expressions); + return compile(FunctionProperties.None, functionName, expressions); } public static FunctionExpression isnull(Expression... expressions) { - return compile(BuiltinFunctionName.ISNULL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.ISNULL, expressions); } public static FunctionExpression is_null(Expression... expressions) { - return compile(BuiltinFunctionName.IS_NULL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.IS_NULL, expressions); } public static FunctionExpression isnotnull(Expression... expressions) { - return compile(BuiltinFunctionName.IS_NOT_NULL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.IS_NOT_NULL, expressions); } public static FunctionExpression ifnull(Expression... expressions) { - return compile(BuiltinFunctionName.IFNULL, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.IFNULL, expressions); } public static FunctionExpression nullif(Expression... expressions) { - return compile(BuiltinFunctionName.NULLIF, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.NULLIF, expressions); } public static FunctionExpression iffunction(Expression... expressions) { - return compile(BuiltinFunctionName.IF, expressions); + return compile(FunctionProperties.None, BuiltinFunctionName.IF, expressions); } public static Expression cases(Expression defaultResult, @@ -611,132 +612,143 @@ public static WhenClause when(Expression condition, Expression result) { } public static FunctionExpression interval(Expression value, Expression unit) { - return compile(BuiltinFunctionName.INTERVAL, value, unit); + return compile(FunctionProperties.None, BuiltinFunctionName.INTERVAL, value, unit); } public static FunctionExpression castString(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_STRING, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_STRING, value); } public static FunctionExpression castByte(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_BYTE, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_BYTE, value); } public static FunctionExpression castShort(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_SHORT, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_SHORT, value); } public static FunctionExpression castInt(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_INT, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_INT, value); } public static FunctionExpression castLong(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_LONG, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_LONG, value); } public static FunctionExpression castFloat(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_FLOAT, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_FLOAT, value); } public static FunctionExpression castDouble(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_DOUBLE, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DOUBLE, value); } public static FunctionExpression castBoolean(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_BOOLEAN, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_BOOLEAN, value); } public static FunctionExpression castDate(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_DATE, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DATE, value); } public static FunctionExpression castTime(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_TIME, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_TIME, value); } public static FunctionExpression castTimestamp(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_TIMESTAMP, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_TIMESTAMP, value); } public static FunctionExpression castDatetime(Expression value) { - return compile(BuiltinFunctionName.CAST_TO_DATETIME, value); + return compile(FunctionProperties.None, BuiltinFunctionName.CAST_TO_DATETIME, value); } public static FunctionExpression typeof(Expression value) { - return compile(BuiltinFunctionName.TYPEOF, value); + return compile(FunctionProperties.None, BuiltinFunctionName.TYPEOF, value); } public static FunctionExpression match(Expression... args) { - return compile(BuiltinFunctionName.MATCH, args); + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH, args); } public static FunctionExpression match_phrase(Expression... args) { - return compile(BuiltinFunctionName.MATCH_PHRASE, args); + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE, args); } public static FunctionExpression match_phrase_prefix(Expression... args) { - return compile(BuiltinFunctionName.MATCH_PHRASE_PREFIX, args); + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_PHRASE_PREFIX, args); } public static FunctionExpression multi_match(Expression... args) { - return compile(BuiltinFunctionName.MULTI_MATCH, args); + return compile(FunctionProperties.None, BuiltinFunctionName.MULTI_MATCH, args); } public static FunctionExpression simple_query_string(Expression... args) { - return compile(BuiltinFunctionName.SIMPLE_QUERY_STRING, args); + return compile(FunctionProperties.None, BuiltinFunctionName.SIMPLE_QUERY_STRING, args); } public static FunctionExpression query(Expression... args) { - return compile(BuiltinFunctionName.QUERY, args); + return compile(FunctionProperties.None, BuiltinFunctionName.QUERY, args); } public static FunctionExpression query_string(Expression... args) { - return compile(BuiltinFunctionName.QUERY_STRING, args); + return compile(FunctionProperties.None, BuiltinFunctionName.QUERY_STRING, args); } public static FunctionExpression match_bool_prefix(Expression... args) { - return compile(BuiltinFunctionName.MATCH_BOOL_PREFIX, args); + return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_BOOL_PREFIX, args); } - public static FunctionExpression now(Expression... args) { - return compile(BuiltinFunctionName.NOW, args); + public static FunctionExpression now(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.NOW, args); } - public static FunctionExpression current_timestamp(Expression... args) { - return compile(BuiltinFunctionName.CURRENT_TIMESTAMP, args); + public static FunctionExpression current_timestamp(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURRENT_TIMESTAMP, args); } - public static FunctionExpression localtimestamp(Expression... args) { - return compile(BuiltinFunctionName.LOCALTIMESTAMP, args); + public static FunctionExpression localtimestamp(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.LOCALTIMESTAMP, args); } - public static FunctionExpression localtime(Expression... args) { - return compile(BuiltinFunctionName.LOCALTIME, args); + public static FunctionExpression localtime(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.LOCALTIME, args); } - public static FunctionExpression sysdate(Expression... args) { - return compile(BuiltinFunctionName.SYSDATE, args); + public static FunctionExpression sysdate(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.SYSDATE, args); } - public static FunctionExpression curtime(Expression... args) { - return compile(BuiltinFunctionName.CURTIME, args); + public static FunctionExpression curtime(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURTIME, args); } - public static FunctionExpression current_time(Expression... args) { - return compile(BuiltinFunctionName.CURRENT_TIME, args); + public static FunctionExpression current_time(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURRENT_TIME, args); } - public static FunctionExpression curdate(Expression... args) { - return compile(BuiltinFunctionName.CURDATE, args); + public static FunctionExpression curdate(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURDATE, args); } - public static FunctionExpression current_date(Expression... args) { - return compile(BuiltinFunctionName.CURRENT_DATE, args); + public static FunctionExpression current_date(FunctionProperties functionProperties, + Expression... args) { + return compile(functionProperties, BuiltinFunctionName.CURRENT_DATE, args); } @SuppressWarnings("unchecked") private static - T compile(BuiltinFunctionName bfn, Expression... args) { - return (T) BuiltinFunctionRepository.getInstance().compile(bfn.getName(), Arrays.asList(args)); + T compile(FunctionProperties functionProperties, + BuiltinFunctionName bfn, Expression... args) { + return (T) BuiltinFunctionRepository.getInstance().compile(functionProperties, + bfn.getName(), Arrays.asList(args)); } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index 9fbf1557aa..c30ca13351 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -68,7 +68,7 @@ private static DefaultFunctionResolver avg() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new AvgAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new AvgAggregator(arguments, DOUBLE)) .build() ); } @@ -78,7 +78,7 @@ private static DefaultFunctionResolver count() { DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, ExprCoreType.coreTypes().stream().collect(Collectors.toMap( type -> new FunctionSignature(functionName, Collections.singletonList(type)), - type -> arguments -> new CountAggregator(arguments, INTEGER)))); + type -> (functionProperties, arguments) -> new CountAggregator(arguments, INTEGER)))); return functionResolver; } @@ -88,13 +88,13 @@ private static DefaultFunctionResolver sum() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new SumAggregator(arguments, INTEGER)) + (functionProperties, arguments) -> new SumAggregator(arguments, INTEGER)) .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new SumAggregator(arguments, LONG)) + (functionProperties, arguments) -> new SumAggregator(arguments, LONG)) .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new SumAggregator(arguments, FLOAT)) + (functionProperties, arguments) -> new SumAggregator(arguments, FLOAT)) .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new SumAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new SumAggregator(arguments, DOUBLE)) .build() ); } @@ -105,23 +105,23 @@ private static DefaultFunctionResolver min() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new MinAggregator(arguments, INTEGER)) + (functionProperties, arguments) -> new MinAggregator(arguments, INTEGER)) .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new MinAggregator(arguments, LONG)) + (functionProperties, arguments) -> new MinAggregator(arguments, LONG)) .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new MinAggregator(arguments, FLOAT)) + (functionProperties, arguments) -> new MinAggregator(arguments, FLOAT)) .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new MinAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new MinAggregator(arguments, DOUBLE)) .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - arguments -> new MinAggregator(arguments, STRING)) + (functionProperties, arguments) -> new MinAggregator(arguments, STRING)) .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - arguments -> new MinAggregator(arguments, DATE)) + (functionProperties, arguments) -> new MinAggregator(arguments, DATE)) .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - arguments -> new MinAggregator(arguments, DATETIME)) + (functionProperties, arguments) -> new MinAggregator(arguments, DATETIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - arguments -> new MinAggregator(arguments, TIME)) + (functionProperties, arguments) -> new MinAggregator(arguments, TIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - arguments -> new MinAggregator(arguments, TIMESTAMP)) + (functionProperties, arguments) -> new MinAggregator(arguments, TIMESTAMP)) .build()); } @@ -131,23 +131,23 @@ private static DefaultFunctionResolver max() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(INTEGER)), - arguments -> new MaxAggregator(arguments, INTEGER)) + (functionProperties, arguments) -> new MaxAggregator(arguments, INTEGER)) .put(new FunctionSignature(functionName, Collections.singletonList(LONG)), - arguments -> new MaxAggregator(arguments, LONG)) + (functionProperties, arguments) -> new MaxAggregator(arguments, LONG)) .put(new FunctionSignature(functionName, Collections.singletonList(FLOAT)), - arguments -> new MaxAggregator(arguments, FLOAT)) + (functionProperties, arguments) -> new MaxAggregator(arguments, FLOAT)) .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> new MaxAggregator(arguments, DOUBLE)) + (functionProperties, arguments) -> new MaxAggregator(arguments, DOUBLE)) .put(new FunctionSignature(functionName, Collections.singletonList(STRING)), - arguments -> new MaxAggregator(arguments, STRING)) + (functionProperties, arguments) -> new MaxAggregator(arguments, STRING)) .put(new FunctionSignature(functionName, Collections.singletonList(DATE)), - arguments -> new MaxAggregator(arguments, DATE)) + (functionProperties, arguments) -> new MaxAggregator(arguments, DATE)) .put(new FunctionSignature(functionName, Collections.singletonList(DATETIME)), - arguments -> new MaxAggregator(arguments, DATETIME)) + (functionProperties, arguments) -> new MaxAggregator(arguments, DATETIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIME)), - arguments -> new MaxAggregator(arguments, TIME)) + (functionProperties, arguments) -> new MaxAggregator(arguments, TIME)) .put(new FunctionSignature(functionName, Collections.singletonList(TIMESTAMP)), - arguments -> new MaxAggregator(arguments, TIMESTAMP)) + (functionProperties, arguments) -> new MaxAggregator(arguments, TIMESTAMP)) .build() ); } @@ -158,7 +158,7 @@ private static DefaultFunctionResolver varSamp() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> varianceSample(arguments, DOUBLE)) + (functionProperties, arguments) -> varianceSample(arguments, DOUBLE)) .build() ); } @@ -169,7 +169,7 @@ private static DefaultFunctionResolver varPop() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> variancePopulation(arguments, DOUBLE)) + (functionProperties, arguments) -> variancePopulation(arguments, DOUBLE)) .build() ); } @@ -180,7 +180,7 @@ private static DefaultFunctionResolver stddevSamp() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> stddevSample(arguments, DOUBLE)) + (functionProperties, arguments) -> stddevSample(arguments, DOUBLE)) .build() ); } @@ -191,7 +191,7 @@ private static DefaultFunctionResolver stddevPop() { functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, Collections.singletonList(DOUBLE)), - arguments -> stddevPopulation(arguments, DOUBLE)) + (functionProperties, arguments) -> stddevPopulation(arguments, DOUBLE)) .build() ); } @@ -201,7 +201,7 @@ private static DefaultFunctionResolver take() { DefaultFunctionResolver functionResolver = new DefaultFunctionResolver(functionName, new ImmutableMap.Builder() .put(new FunctionSignature(functionName, ImmutableList.of(STRING, INTEGER)), - arguments -> new TakeAggregator(arguments, ARRAY)) + (functionProperties, arguments) -> new TakeAggregator(arguments, ARRAY)) .build()); return functionResolver; } diff --git a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java index 42274c0756..0d194e7197 100644 --- a/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/datetime/DateTimeFunction.java @@ -18,6 +18,7 @@ import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.expression.function.FunctionDSL.define; import static org.opensearch.sql.expression.function.FunctionDSL.impl; +import static org.opensearch.sql.expression.function.FunctionDSL.implWithProperties; import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_FORMATTER_LONG_YEAR; import static org.opensearch.sql.utils.DateTimeFormatters.DATE_FORMATTER_SHORT_YEAR; @@ -28,6 +29,7 @@ import java.math.BigDecimal; import java.math.RoundingMode; import java.text.DecimalFormat; +import java.time.Clock; import java.time.DateTimeException; import java.time.Instant; import java.time.LocalDate; @@ -42,7 +44,6 @@ import java.util.Locale; import java.util.TimeZone; import java.util.concurrent.TimeUnit; -import javax.annotation.Nullable; import lombok.experimental.UtilityClass; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDatetimeValue; @@ -60,7 +61,9 @@ import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; import org.opensearch.sql.expression.function.DefaultFunctionResolver; +import org.opensearch.sql.expression.function.FunctionDSL; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionResolver; import org.opensearch.sql.utils.DateTimeUtils; /** @@ -128,6 +131,78 @@ public void register(BuiltinFunctionRepository repository) { repository.register(year()); } + /** + * NOW() returns a constant time that indicates the time at which the statement began to execute. + * `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and + * `now(y) return different values. + */ + private FunctionResolver now(FunctionName functionName) { + return define(functionName, + implWithProperties( + functionProperties -> new ExprDatetimeValue( + formatNow(functionProperties.getQueryStartClock())), DATETIME) + ); + } + + private FunctionResolver now() { + return now(BuiltinFunctionName.NOW.getName()); + } + + private FunctionResolver current_timestamp() { + return now(BuiltinFunctionName.CURRENT_TIMESTAMP.getName()); + } + + private FunctionResolver localtimestamp() { + return now(BuiltinFunctionName.LOCALTIMESTAMP.getName()); + } + + private FunctionResolver localtime() { + return now(BuiltinFunctionName.LOCALTIME.getName()); + } + + /** + * SYSDATE() returns the time at which it executes. + */ + private FunctionResolver sysdate() { + return define(BuiltinFunctionName.SYSDATE.getName(), + implWithProperties(functionProperties + -> new ExprDatetimeValue(formatNow(Clock.systemDefaultZone())), DATETIME), + FunctionDSL.implWithProperties((functionProperties, v) -> new ExprDatetimeValue( + formatNow(Clock.systemDefaultZone(), v.integerValue())), DATETIME, INTEGER) + ); + } + + /** + * Synonym for @see `now`. + */ + private FunctionResolver curtime(FunctionName functionName) { + return define(functionName, + implWithProperties(functionProperties -> new ExprTimeValue( + formatNow(functionProperties.getQueryStartClock()).toLocalTime()), TIME)); + } + + private FunctionResolver curtime() { + return curtime(BuiltinFunctionName.CURTIME.getName()); + } + + private FunctionResolver current_time() { + return curtime(BuiltinFunctionName.CURRENT_TIME.getName()); + } + + private FunctionResolver curdate(FunctionName functionName) { + return define(functionName, + implWithProperties(functionProperties -> new ExprDateValue( + formatNow(functionProperties.getQueryStartClock()).toLocalDate()), DATE)); + } + + private FunctionResolver curdate() { + return curdate(BuiltinFunctionName.CURDATE.getName()); + } + + private FunctionResolver current_date() { + return curdate(BuiltinFunctionName.CURRENT_DATE.getName()); + } + /** * Specify a start date and add a temporal amount to the date. * The return type depends on the date type and the interval unit. Detailed supported signatures: @@ -170,41 +245,6 @@ private DefaultFunctionResolver convert_tz() { ); } - private DefaultFunctionResolver curdate(FunctionName functionName) { - return define(functionName, - impl(() -> new ExprDateValue(formatNow(null).toLocalDate()), DATE) - ); - } - - private DefaultFunctionResolver curdate() { - return curdate(BuiltinFunctionName.CURDATE.getName()); - } - - /** - * Synonym for @see `now`. - */ - private DefaultFunctionResolver curtime(FunctionName functionName) { - return define(functionName, - impl(() -> new ExprTimeValue(formatNow(null).toLocalTime()), TIME) - ); - } - - private DefaultFunctionResolver curtime() { - return curtime(BuiltinFunctionName.CURTIME.getName()); - } - - private DefaultFunctionResolver current_date() { - return curdate(BuiltinFunctionName.CURRENT_DATE.getName()); - } - - private DefaultFunctionResolver current_time() { - return curtime(BuiltinFunctionName.CURRENT_TIME.getName()); - } - - private DefaultFunctionResolver current_timestamp() { - return now(BuiltinFunctionName.CURRENT_TIMESTAMP.getName()); - } - /** * Extracts the date part of a date and time value. * Also to construct a date type. The supported signatures: @@ -224,7 +264,7 @@ private DefaultFunctionResolver date() { * (STRING, STRING) -> DATETIME * (STRING) -> DATETIME */ - private DefaultFunctionResolver datetime() { + private FunctionResolver datetime() { return define(BuiltinFunctionName.DATETIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprDateTime), DATETIME, STRING, STRING), @@ -336,7 +376,7 @@ private DefaultFunctionResolver from_days() { impl(nullMissingHandling(DateTimeFunction::exprFromDays), DATE, LONG)); } - private DefaultFunctionResolver from_unixtime() { + private FunctionResolver from_unixtime() { return define(BuiltinFunctionName.FROM_UNIXTIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprFromUnixTime), DATETIME, DOUBLE), impl(nullMissingHandling(DateTimeFunction::exprFromUnixTimeFormat), @@ -355,35 +395,12 @@ private DefaultFunctionResolver hour() { ); } - private DefaultFunctionResolver localtime() { - return now(BuiltinFunctionName.LOCALTIME.getName()); - } - - private DefaultFunctionResolver localtimestamp() { - return now(BuiltinFunctionName.LOCALTIMESTAMP.getName()); - } - - /** - * NOW() returns a constant time that indicates the time at which the statement began to execute. - * `fsp` argument support is removed until refactoring to avoid bug where `now()`, `now(x)` and - * `now(y) return different values. - */ - private DefaultFunctionResolver now(FunctionName functionName) { - return define(functionName, - impl(() -> new ExprDatetimeValue(formatNow(null)), DATETIME) - ); - } - - private DefaultFunctionResolver now() { - return now(BuiltinFunctionName.NOW.getName()); - } - - private DefaultFunctionResolver makedate() { + private FunctionResolver makedate() { return define(BuiltinFunctionName.MAKEDATE.getName(), impl(nullMissingHandling(DateTimeFunction::exprMakeDate), DATE, DOUBLE, DOUBLE)); } - private DefaultFunctionResolver maketime() { + private FunctionResolver maketime() { return define(BuiltinFunctionName.MAKETIME.getName(), impl(nullMissingHandling(DateTimeFunction::exprMakeTime), TIME, DOUBLE, DOUBLE, DOUBLE)); } @@ -535,9 +552,10 @@ private DefaultFunctionResolver to_days() { impl(nullMissingHandling(DateTimeFunction::exprToDays), LONG, DATETIME)); } - private DefaultFunctionResolver unix_timestamp() { + private FunctionResolver unix_timestamp() { return define(BuiltinFunctionName.UNIX_TIMESTAMP.getName(), - impl(DateTimeFunction::unixTimeStamp, LONG), + implWithProperties(functionProperties + -> DateTimeFunction.unixTimeStamp(functionProperties.getQueryStartClock()), LONG), impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATE), impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, DATETIME), impl(nullMissingHandling(DateTimeFunction::unixTimeStampOf), DOUBLE, TIMESTAMP), @@ -1004,16 +1022,6 @@ private ExprValue exprSubDateInterval(ExprValue date, ExprValue expr) { : exprValue); } - /** - * SYSDATE() returns the time at which it executes. - */ - private DefaultFunctionResolver sysdate() { - return define(BuiltinFunctionName.SYSDATE.getName(), - impl(() -> new ExprDatetimeValue(formatNow(null)), DATETIME), - impl((v) -> new ExprDatetimeValue(formatNow(v.integerValue())), DATETIME, INTEGER) - ); - } - /** * Time implementation for ExprValue. * @@ -1073,8 +1081,8 @@ private ExprValue exprWeek(ExprValue date, ExprValue mode) { CalendarLookup.getWeekNumber(mode.integerValue(), date.dateValue())); } - private ExprValue unixTimeStamp() { - return new ExprLongValue(Instant.now().getEpochSecond()); + private ExprValue unixTimeStamp(Clock clock) { + return new ExprLongValue(Instant.now(clock).getEpochSecond()); } private ExprValue unixTimeStampOf(ExprValue value) { @@ -1167,17 +1175,18 @@ private ExprValue exprYear(ExprValue date) { return new ExprIntegerValue(date.dateValue().getYear()); } + private LocalDateTime formatNow(Clock clock) { + return formatNow(clock, 0); + } + /** * Prepare LocalDateTime value. Truncate fractional second part according to the argument. * @param fsp argument is given to specify a fractional seconds precision from 0 to 6, * the return value includes a fractional seconds part of that many digits. * @return LocalDateTime object. */ - private LocalDateTime formatNow(@Nullable Integer fsp) { - var res = LocalDateTime.now(); - if (fsp == null) { - fsp = 0; - } + private LocalDateTime formatNow(Clock clock, Integer fsp) { + var res = LocalDateTime.now(clock); var defaultPrecision = 9; // There are 10^9 nanoseconds in one second if (fsp < 0 || fsp > 6) { // Check that the argument is in the allowed range [0, 6] throw new IllegalArgumentException( diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java index 22c588e679..71fd19991e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionRepository.java @@ -9,11 +9,11 @@ import static org.opensearch.sql.ast.expression.Cast.isCastFunction; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableList; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.common.utils.StringUtils; @@ -40,7 +40,6 @@ * */ public class BuiltinFunctionRepository { - public static final String DEFAULT_NAMESPACE = "default"; private final Map> namespaceFunctionResolverMap; @@ -111,13 +110,13 @@ public void register(String namespace, FunctionResolver resolver) { namespaceFunctionResolverMap.get(namespace).put(resolver.getFunctionName(), resolver); } - /** * Compile FunctionExpression under default namespace. * */ - public FunctionImplementation compile(FunctionName functionName, List expressions) { - return compile(DEFAULT_NAMESPACE, functionName, expressions); + public FunctionImplementation compile(FunctionProperties functionProperties, + FunctionName functionName, List expressions) { + return compile(functionProperties, DEFAULT_NAMESPACE, functionName, expressions); } @@ -125,16 +124,18 @@ public FunctionImplementation compile(FunctionName functionName, List expressions) { List namespaceList = new ArrayList<>(List.of(DEFAULT_NAMESPACE)); if (!namespace.equals(DEFAULT_NAMESPACE)) { namespaceList.add(namespace); } - FunctionBuilder resolvedFunctionBuilder = resolve(namespaceList, - new FunctionSignature(functionName, expressions - .stream().map(expression -> expression.type()).collect(Collectors.toList()))); - return resolvedFunctionBuilder.apply(expressions); + FunctionBuilder resolvedFunctionBuilder = resolve( + namespaceList, new FunctionSignature(functionName, expressions + .stream().map(Expression::type).collect(Collectors.toList()))); + return resolvedFunctionBuilder.apply(functionProperties, expressions); } /** @@ -144,11 +145,12 @@ public FunctionImplementation compile(String namespace, FunctionName functionNam * So list of namespaces is also the priority of namespaces. * * @param functionSignature {@link FunctionSignature} functionsignature. - * * @return Original function builder if it's a cast function or all arguments have expected types - * or other wise wrap its arguments by cast function as needed. + * or otherwise wrap its arguments by cast function as needed. */ - public FunctionBuilder resolve(List namespaces, FunctionSignature functionSignature) { + public FunctionBuilder + resolve(List namespaces, + FunctionSignature functionSignature) { FunctionName functionName = functionSignature.getFunctionName(); FunctionBuilder result = null; for (String namespace : namespaces) { @@ -167,9 +169,10 @@ public FunctionBuilder resolve(List namespaces, FunctionSignature functi } } - private FunctionBuilder getFunctionBuilder(FunctionSignature functionSignature, - FunctionName functionName, - Map functionResolverMap) { + private FunctionBuilder getFunctionBuilder( + FunctionSignature functionSignature, + FunctionName functionName, + Map functionResolverMap) { Pair resolvedSignature = functionResolverMap.get(functionName).resolve(functionSignature); @@ -179,7 +182,8 @@ private FunctionBuilder getFunctionBuilder(FunctionSignature functionSignature, if (isCastFunction(functionName) || sourceTypes.equals(targetTypes)) { return funcBuilder; } - return castArguments(sourceTypes, targetTypes, funcBuilder); + return castArguments(sourceTypes, + targetTypes, funcBuilder); } /** @@ -191,7 +195,7 @@ private FunctionBuilder getFunctionBuilder(FunctionSignature functionSignature, private FunctionBuilder castArguments(List sourceTypes, List targetTypes, FunctionBuilder funcBuilder) { - return arguments -> { + return (fp, arguments) -> { List argsCasted = new ArrayList<>(); for (int i = 0; i < arguments.size(); i++) { Expression arg = arguments.get(i); @@ -199,12 +203,12 @@ private FunctionBuilder castArguments(List sourceTypes, ExprType targetType = targetTypes.get(i); if (isCastRequired(sourceType, targetType)) { - argsCasted.add(cast(arg, targetType)); + argsCasted.add(cast(arg, targetType).apply(fp)); } else { argsCasted.add(arg); } } - return funcBuilder.apply(argsCasted); + return funcBuilder.apply(fp, argsCasted); }; } @@ -217,13 +221,13 @@ private boolean isCastRequired(ExprType sourceType, ExprType targetType) { return sourceType.shouldCast(targetType); } - private Expression cast(Expression arg, ExprType targetType) { + private Function cast(Expression arg, ExprType targetType) { FunctionName castFunctionName = getCastFunctionName(targetType); if (castFunctionName == null) { throw new ExpressionEvaluationException(StringUtils.format( "Type conversion to type %s is not supported", targetType)); } - return (Expression) compile(castFunctionName, ImmutableList.of(arg)); + return functionProperties -> (Expression) compile(functionProperties, + castFunctionName, List.of(arg)); } - } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java index aa3077051b..b6e32a1d27 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionBuilder.java @@ -18,8 +18,9 @@ public interface FunctionBuilder { /** * Create {@link FunctionImplementation} from input {@link Expression} list. * - * @param arguments {@link Expression} list + * @param functionProperties context for function execution. + * @param arguments {@link Expression} list. * @return {@link FunctionImplementation} */ - FunctionImplementation apply(List arguments); + FunctionImplementation apply(FunctionProperties functionProperties, List arguments); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java index 1fad333ead..5b182f76f4 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionDSL.java @@ -19,6 +19,7 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.DefaultFunctionResolver.DefaultFunctionResolverBuilder; /** * Function Define Utility. @@ -35,7 +36,7 @@ public class FunctionDSL { public static DefaultFunctionResolver define(FunctionName functionName, SerializableFunction>... functions) { - return define(functionName, Arrays.asList(functions)); + return define(functionName, List.of(functions)); } /** @@ -48,8 +49,7 @@ public static DefaultFunctionResolver define(FunctionName functionName, public static DefaultFunctionResolver define(FunctionName functionName, List< SerializableFunction>> functions) { - DefaultFunctionResolver.DefaultFunctionResolverBuilder builder - = DefaultFunctionResolver.builder(); + DefaultFunctionResolverBuilder builder = DefaultFunctionResolver.builder(); builder.functionName(functionName); for (Function> func : functions) { Pair functionBuilder = func.apply(functionName); @@ -58,51 +58,54 @@ public static DefaultFunctionResolver define(FunctionName functionName, List< return builder.build(); } + /** - * No Arg Function Implementation. + * Implementation of no args function that uses FunctionProperties. * - * @param function {@link ExprValue} based unary function. - * @param returnType return type. - * @return Unary Function Implementation. + * @param function {@link ExprValue} based no args function. + * @param returnType function return type. + * @return no args function implementation. */ - public static SerializableFunction> impl( - SerializableNoArgFunction function, - ExprType returnType) { - + public static SerializableFunction> + implWithProperties(SerializableFunction function, + ExprType returnType) { return functionName -> { FunctionSignature functionSignature = new FunctionSignature(functionName, Collections.emptyList()); FunctionBuilder functionBuilder = - arguments -> new FunctionExpression(functionName, Collections.emptyList()) { - @Override - public ExprValue valueOf(Environment valueEnv) { - return function.get(); - } + (functionProperties, arguments) -> + new FunctionExpression(functionName, Collections.emptyList()) { + @Override + public ExprValue valueOf(Environment valueEnv) { + return function.apply(functionProperties); + } - @Override - public ExprType type() { - return returnType; - } + @Override + public ExprType type() { + return returnType; + } - @Override - public String toString() { - return String.format("%s()", functionName); - } - }; + @Override + public String toString() { + return String.format("%s()", functionName); + } + }; return Pair.of(functionSignature, functionBuilder); }; } /** - * Unary Function Implementation. + * Implementation of a function that takes one argument, returns a value, and + * requires FunctionProperties to complete. * * @param function {@link ExprValue} based unary function. * @param returnType return type. - * @param argsType argument type. + * @param argsType argument type. * @return Unary Function Implementation. */ - public static SerializableFunction> impl( - SerializableFunction function, + public static SerializableFunction> + implWithProperties( + SerializableBiFunction function, ExprType returnType, ExprType argsType) { @@ -110,11 +113,11 @@ public static SerializableFunction new FunctionExpression(functionName, arguments) { + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { @Override public ExprValue valueOf(Environment valueEnv) { ExprValue value = arguments.get(0).valueOf(valueEnv); - return function.apply(value); + return function.apply(functionProperties, value); } @Override @@ -134,6 +137,35 @@ public String toString() { }; } + /** + * No Arg Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableNoArgFunction function, + ExprType returnType) { + return implWithProperties(fp -> function.get(), returnType); + } + + /** + * Unary Function Implementation. + * + * @param function {@link ExprValue} based unary function. + * @param returnType return type. + * @param argsType argument type. + * @return Unary Function Implementation. + */ + public static SerializableFunction> impl( + SerializableFunction function, + ExprType returnType, + ExprType argsType) { + + return implWithProperties((fp, arg) -> function.apply(arg), returnType, argsType); + } + /** * Binary Function Implementation. * @@ -153,7 +185,7 @@ public static SerializableFunction new FunctionExpression(functionName, arguments) { + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { @Override public ExprValue valueOf(Environment valueEnv) { ExprValue arg1 = arguments.get(0).valueOf(valueEnv); @@ -196,7 +228,7 @@ public static SerializableFunction new FunctionExpression(functionName, arguments) { + (functionProperties, arguments) -> new FunctionExpression(functionName, arguments) { @Override public ExprValue valueOf(Environment valueEnv) { ExprValue arg1 = arguments.get(0).valueOf(valueEnv); @@ -267,4 +299,22 @@ public SerializableTriFunction nullM } }; } + + /** + * Wrapper the unary ExprValue function that is aware of FunctionProperties, + * with default NULL and MISSING handling. + */ + public static SerializableBiFunction + nullMissingHandlingWithProperties( + SerializableBiFunction implementation) { + return (functionProperties, v1) -> { + if (v1.isMissing()) { + return ExprValueUtils.missingValue(); + } else if (v1.isNull()) { + return ExprValueUtils.nullValue(); + } else { + return implementation.apply(functionProperties, v1); + } + }; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java b/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java new file mode 100644 index 0000000000..4222748051 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/function/FunctionProperties.java @@ -0,0 +1,68 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.io.Serializable; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import lombok.EqualsAndHashCode; +import lombok.RequiredArgsConstructor; + +@RequiredArgsConstructor +@EqualsAndHashCode +public class FunctionProperties implements Serializable { + + private final Instant nowInstant; + private final ZoneId currentZoneId; + + /** + * By default, use current time and current timezone. + */ + public FunctionProperties() { + nowInstant = Instant.now(); + currentZoneId = ZoneId.systemDefault(); + } + + /** + * Method to access current system clock. + * @return a ticking clock that tells the time. + */ + public Clock getSystemClock() { + return Clock.system(currentZoneId); + } + + /** + * Method to get time when query began execution. + * Clock class combines an instant Supplier and a time zone. + * @return a fixed clock that returns the time execution started at. + * + */ + public Clock getQueryStartClock() { + return Clock.fixed(nowInstant, currentZoneId); + } + + /** + * Use when compiling functions that do not rely on function properties. + */ + public static final FunctionProperties None = new FunctionProperties() { + @Override + public Clock getSystemClock() { + throw new UnexpectedCallException(); + } + + @Override + public Clock getQueryStartClock() { + throw new UnexpectedCallException(); + } + }; + + class UnexpectedCallException extends RuntimeException { + public UnexpectedCallException() { + super("FunctionProperties.None is a null object and not meant to be accessed."); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java index e781db8c84..7066622e1b 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/RelevanceFunctionResolver.java @@ -47,8 +47,8 @@ public Pair resolve(FunctionSignature unreso } } - FunctionBuilder buildFunction = - args -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); + FunctionBuilder buildFunction = (functionProperties, args) + -> new OpenSearchFunctions.OpenSearchFunction(functionName, args); return Pair.of(unresolvedSignature, buildFunction); } diff --git a/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java b/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java index 5e955c2e62..6d8dd5093e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/system/SystemFunctions.java @@ -38,7 +38,8 @@ private static FunctionResolver typeof() { public Pair resolve( FunctionSignature unresolvedSignature) { return Pair.of(unresolvedSignature, - arguments -> new FunctionExpression(BuiltinFunctionName.TYPEOF.getName(), arguments) { + (functionProperties, arguments) -> + new FunctionExpression(BuiltinFunctionName.TYPEOF.getName(), arguments) { @Override public ExprValue valueOf(Environment valueEnv) { return new ExprStringValue(getArguments().get(0).type().toString()); diff --git a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java index a3baf08ff3..9a9e0c4c86 100644 --- a/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/window/WindowFunctions.java @@ -54,9 +54,8 @@ private DefaultFunctionResolver denseRank() { private DefaultFunctionResolver rankingFunction(FunctionName functionName, Supplier constructor) { FunctionSignature functionSignature = new FunctionSignature(functionName, emptyList()); - FunctionBuilder functionBuilder = arguments -> constructor.get(); + FunctionBuilder functionBuilder = (functionProperties, arguments) -> constructor.get(); return new DefaultFunctionResolver(functionName, ImmutableMap.of(functionSignature, functionBuilder)); } - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index a485541a34..c68ba2653e 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -155,7 +155,8 @@ public Pair resolve( FunctionSignature functionSignature = new FunctionSignature(functionName, List.of(STRING, LONG, LONG, LONG)); return Pair.of(functionSignature, - args -> new TestTableFunctionImplementation(functionName, args, table)); + (functionProperties, args) -> new TestTableFunctionImplementation(functionName, args, + table)); } @Override diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index f10fc281b3..dfb7a7239f 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -48,7 +48,7 @@ import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.expression.function.FunctionPropertiesTestConfig; import org.opensearch.sql.expression.window.aggregation.AggregateWindowFunction; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.ContextConfiguration; @@ -56,7 +56,7 @@ @Configuration @ExtendWith(SpringExtension.class) -@ContextConfiguration(classes = {AnalyzerTestBase.class}) +@ContextConfiguration(classes = {FunctionPropertiesTestConfig.class, AnalyzerTestBase.class}) class ExpressionAnalyzerTest extends AnalyzerTestBase { @Test @@ -564,27 +564,12 @@ public void match_phrase_prefix_all_params() { ); } - @Test - public void constant_function_is_calculated_on_analyze() { - // Actually, we can call any function as ConstantFunction to be calculated on analyze stage - assertTrue(analyze(AstDSL.constantFunction("now")) instanceof LiteralExpression); - assertTrue(analyze(AstDSL.constantFunction("localtime")) instanceof LiteralExpression); - } - @Test public void function_isnt_calculated_on_analyze() { assertTrue(analyze(function("now")) instanceof FunctionExpression); assertTrue(analyze(AstDSL.function("localtime")) instanceof FunctionExpression); } - @Test - public void constant_function_returns_constant_cached_value() { - var values = List.of(analyze(AstDSL.constantFunction("now")), - analyze(AstDSL.constantFunction("now")), analyze(AstDSL.constantFunction("now"))); - assertTrue(values.stream().allMatch(v -> - v.valueOf() == analyze(AstDSL.constantFunction("now")).valueOf())); - } - @Test public void function_returns_non_constant_value() { // Even a function returns the same values - they are calculated on each call diff --git a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java index d2154e26f4..c73bd8ac18 100644 --- a/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/ExpressionTestBase.java @@ -35,6 +35,8 @@ import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.env.Environment; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; +import org.opensearch.sql.expression.function.FunctionPropertiesTestConfig; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; @@ -43,10 +45,13 @@ @Configuration @ExtendWith(SpringExtension.class) -@ContextConfiguration(classes = {ExpressionTestBase.class, +@ContextConfiguration(classes = {FunctionPropertiesTestConfig.class, ExpressionTestBase.class, TestConfig.class}) public class ExpressionTestBase { + @Autowired + protected FunctionProperties functionProperties; + @Autowired protected Environment typeEnv; diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java b/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java index 7c8464e79c..555759f1b1 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/DateTimeTestBase.java @@ -5,21 +5,19 @@ package org.opensearch.sql.expression.datetime; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.expression.function.BuiltinFunctionRepository.DEFAULT_NAMESPACE; - import java.time.Instant; import java.time.LocalDate; import java.time.LocalDateTime; import java.time.LocalTime; -import java.util.Collections; import java.util.List; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.data.model.ExprDateValue; import org.opensearch.sql.data.model.ExprDatetimeValue; +import org.opensearch.sql.data.model.ExprMissingValue; +import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprTimestampValue; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.expression.DSL; @@ -27,113 +25,114 @@ import org.opensearch.sql.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; -import org.opensearch.sql.expression.function.FunctionName; -import org.opensearch.sql.expression.function.FunctionSignature; -import org.springframework.beans.factory.annotation.Autowired; +import org.opensearch.sql.expression.function.FunctionProperties; @ExtendWith(MockitoExtension.class) public class DateTimeTestBase extends ExpressionTestBase { + protected final BuiltinFunctionRepository functionRepository + = BuiltinFunctionRepository.getInstance(); + @Mock protected Environment env; - @Mock - protected Expression nullRef; - @Mock - protected Expression missingRef; + protected static FunctionProperties functionProperties; + + @BeforeAll + public static void setup() { + functionProperties = new FunctionProperties(); + } + + protected Expression nullRef = DSL.literal(ExprNullValue.of()); + + protected Expression missingRef = DSL.literal(ExprMissingValue.of()); protected ExprValue eval(Expression expression) { return expression.valueOf(env); } + protected LocalDateTime fromUnixTime(Double value) { + return fromUnixTime(DSL.literal(value)).valueOf().datetimeValue(); + } + protected FunctionExpression fromUnixTime(Expression value) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("from_unixtime"), - List.of(value.type()))); - return (FunctionExpression)func.apply(List.of(value)); + return (FunctionExpression) + functionRepository.compile(functionProperties, + BuiltinFunctionName.FROM_UNIXTIME.getName(), List.of(value)); } protected FunctionExpression fromUnixTime(Expression value, Expression format) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("from_unixtime"), - List.of(value.type(), format.type()))); - return (FunctionExpression)func.apply(List.of(value, format)); + return (FunctionExpression) + functionRepository.compile( + functionProperties, + BuiltinFunctionName.FROM_UNIXTIME.getName(), List.of(value, format)); } protected LocalDateTime fromUnixTime(Long value) { return fromUnixTime(DSL.literal(value)).valueOf().datetimeValue(); } - protected LocalDateTime fromUnixTime(Double value) { - return fromUnixTime(DSL.literal(value)).valueOf().datetimeValue(); - } - protected String fromUnixTime(Long value, String format) { - return fromUnixTime(DSL.literal(value), DSL.literal(format)).valueOf().stringValue(); + return fromUnixTime(DSL.literal(value), DSL.literal(format)) + .valueOf().stringValue(); } protected String fromUnixTime(Double value, String format) { - return fromUnixTime(DSL.literal(value), DSL.literal(format)).valueOf().stringValue(); - } - - protected FunctionExpression makedate(Expression year, Expression dayOfYear) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("makedate"), - List.of(DOUBLE, DOUBLE))); - return (FunctionExpression)func.apply(List.of(year, dayOfYear)); - } - - protected LocalDate makedate(Double year, Double dayOfYear) { - return makedate(DSL.literal(year), DSL.literal(dayOfYear)).valueOf().dateValue(); + return fromUnixTime(DSL.literal(value), DSL.literal(format)) + .valueOf().stringValue(); } protected FunctionExpression maketime(Expression hour, Expression minute, Expression second) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("maketime"), - List.of(DOUBLE, DOUBLE, DOUBLE))); - return (FunctionExpression)func.apply(List.of(hour, minute, second)); + return (FunctionExpression) + functionRepository.compile( + functionProperties, + BuiltinFunctionName.MAKETIME.getName(), List.of(hour, minute, second)); } + protected LocalTime maketime(Double hour, Double minute, Double second) { return maketime(DSL.literal(hour), DSL.literal(minute), DSL.literal(second)) .valueOf().timeValue(); } + protected FunctionExpression makedate(Expression year, Expression dayOfYear) { + return (FunctionExpression) functionRepository.compile( + functionProperties, + BuiltinFunctionName.MAKEDATE.getName(), List.of(year, dayOfYear)); + } + + protected LocalDate makedate(double year, double dayOfYear) { + return makedate(DSL.literal(year), DSL.literal(dayOfYear)).valueOf(null).dateValue(); + } + protected FunctionExpression period_add(Expression period, Expression months) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("period_add"), - List.of(INTEGER, INTEGER))); - return (FunctionExpression)func.apply(List.of(period, months)); + return (FunctionExpression) functionRepository.compile( + functionProperties, + BuiltinFunctionName.PERIOD_ADD.getName(), List.of(period, months)); } protected Integer period_add(Integer period, Integer months) { - return period_add(DSL.literal(period), DSL.literal(months)).valueOf().integerValue(); + return period_add(DSL.literal(period), DSL.literal(months)) + .valueOf().integerValue(); } protected FunctionExpression period_diff(Expression first, Expression second) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("period_diff"), - List.of(INTEGER, INTEGER))); - return (FunctionExpression)func.apply(List.of(first, second)); + return (FunctionExpression) functionRepository.compile( + functionProperties, + BuiltinFunctionName.PERIOD_DIFF.getName(), List.of(first, second)); } protected Integer period_diff(Integer first, Integer second) { - return period_diff(DSL.literal(first), DSL.literal(second)).valueOf().integerValue(); + return period_diff(DSL.literal(first), DSL.literal(second)) + .valueOf().integerValue(); } protected FunctionExpression unixTimeStampExpr() { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("unix_timestamp"), List.of())); - return (FunctionExpression)func.apply(List.of()); + return (FunctionExpression) functionRepository.compile( + functionProperties, BuiltinFunctionName.UNIX_TIMESTAMP.getName(), List.of()); } protected Long unixTimeStamp() { @@ -141,26 +140,28 @@ protected Long unixTimeStamp() { } protected FunctionExpression unixTimeStampOf(Expression value) { - var func = BuiltinFunctionRepository.getInstance() - .resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(new FunctionName("unix_timestamp"), - List.of(value.type()))); - return (FunctionExpression)func.apply(List.of(value)); + return (FunctionExpression) + functionRepository.compile(functionProperties, + BuiltinFunctionName.UNIX_TIMESTAMP.getName(), List.of(value)); } + protected Double unixTimeStampOf(Double value) { return unixTimeStampOf(DSL.literal(value)).valueOf().doubleValue(); } protected Double unixTimeStampOf(LocalDate value) { - return unixTimeStampOf(DSL.literal(new ExprDateValue(value))).valueOf().doubleValue(); + return unixTimeStampOf(DSL.literal(new ExprDateValue(value))) + .valueOf().doubleValue(); } protected Double unixTimeStampOf(LocalDateTime value) { - return unixTimeStampOf(DSL.literal(new ExprDatetimeValue(value))).valueOf().doubleValue(); + return unixTimeStampOf(DSL.literal(new ExprDatetimeValue(value))) + .valueOf().doubleValue(); } protected Double unixTimeStampOf(Instant value) { - return unixTimeStampOf(DSL.literal(new ExprTimestampValue(value))).valueOf().doubleValue(); + return unixTimeStampOf(DSL.literal(new ExprTimestampValue(value))) + .valueOf().doubleValue(); } } diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java index 981468f31e..87cbc7ae48 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeDateTest.java @@ -51,8 +51,6 @@ public void checkRounding() { @Test public void checkNullValues() { - when(nullRef.valueOf(env)).thenReturn(nullValue()); - assertEquals(nullValue(), eval(makedate(nullRef, DSL.literal(42.)))); assertEquals(nullValue(), eval(makedate(DSL.literal(42.), nullRef))); assertEquals(nullValue(), eval(makedate(nullRef, nullRef))); @@ -60,8 +58,6 @@ public void checkNullValues() { @Test public void checkMissingValues() { - when(missingRef.valueOf(env)).thenReturn(missingValue()); - assertEquals(missingValue(), eval(makedate(missingRef, DSL.literal(42.)))); assertEquals(missingValue(), eval(makedate(DSL.literal(42.), missingRef))); assertEquals(missingValue(), eval(makedate(missingRef, missingRef))); diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java index 7dd78ff845..3fb2472c18 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/MakeTimeTest.java @@ -58,8 +58,6 @@ public void checkSecondFraction() { @Test public void checkNullValues() { - when(nullRef.valueOf(env)).thenReturn(nullValue()); - assertEquals(nullValue(), eval(maketime(nullRef, DSL.literal(42.), DSL.literal(42.)))); assertEquals(nullValue(), eval(maketime(DSL.literal(42.), nullRef, DSL.literal(42.)))); assertEquals(nullValue(), eval(maketime(DSL.literal(42.), DSL.literal(42.), nullRef))); @@ -71,8 +69,6 @@ public void checkNullValues() { @Test public void checkMissingValues() { - when(missingRef.valueOf(env)).thenReturn(missingValue()); - assertEquals(missingValue(), eval(maketime(missingRef, DSL.literal(42.), DSL.literal(42.)))); assertEquals(missingValue(), eval(maketime(DSL.literal(42.), missingRef, DSL.literal(42.)))); assertEquals(missingValue(), eval(maketime(DSL.literal(42.), DSL.literal(42.), missingRef))); diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java index 437f484e80..6f7548b5cb 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/NowLikeFunctionTest.java @@ -6,6 +6,7 @@ package org.opensearch.sql.expression.datetime; +import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -18,92 +19,119 @@ import java.time.LocalDateTime; import java.time.LocalTime; import java.time.Period; +import java.time.temporal.ChronoUnit; import java.time.temporal.Temporal; +import java.time.temporal.TemporalUnit; import java.util.List; +import java.util.concurrent.Callable; +import java.util.function.BiFunction; import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Stream; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.Arguments; -import org.junit.jupiter.params.provider.MethodSource; +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.core.IsNot; +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestFactory; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.ExpressionTestBase; import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.function.FunctionProperties; -public class NowLikeFunctionTest extends ExpressionTestBase { - private static Stream functionNames() { - return Stream.of( - Arguments.of((Function) DSL::now, - "now", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function) DSL::current_timestamp, - "current_timestamp", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function) DSL::localtimestamp, - "localtimestamp", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function) DSL::localtime, - "localtime", DATETIME, false, (Supplier)LocalDateTime::now), - Arguments.of((Function) DSL::sysdate, - "sysdate", DATETIME, true, (Supplier)LocalDateTime::now), - Arguments.of((Function) DSL::curtime, - "curtime", TIME, false, (Supplier)LocalTime::now), - Arguments.of((Function) DSL::current_time, - "current_time", TIME, false, (Supplier)LocalTime::now), - Arguments.of((Function) DSL::curdate, - "curdate", DATE, false, (Supplier)LocalDate::now), - Arguments.of((Function) DSL::current_date, - "current_date", DATE, false, (Supplier)LocalDate::now)); +class NowLikeFunctionTest extends ExpressionTestBase { + @Test + void now() { + test_now_like_functions(DSL::now, + DATETIME, + false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); } - private Temporal extractValue(FunctionExpression func) { - switch ((ExprCoreType)func.type()) { - case DATE: return func.valueOf().dateValue(); - case DATETIME: return func.valueOf().datetimeValue(); - case TIME: return func.valueOf().timeValue(); - // unreachable code - default: throw new IllegalArgumentException(String.format("%s", func.type())); - } + @Test + void current_timestamp() { + test_now_like_functions(DSL::current_timestamp, DATETIME, false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); } - private long getDiff(Temporal sample, Temporal reference) { - if (sample instanceof LocalDate) { - return Period.between((LocalDate) sample, (LocalDate) reference).getDays(); - } - return Duration.between(sample, reference).toSeconds(); + @Test + void localtimestamp() { + test_now_like_functions(DSL::localtimestamp, DATETIME, false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void localtime() { + test_now_like_functions(DSL::localtime, DATETIME, false, + () -> LocalDateTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void sysdate() { + test_now_like_functions(DSL::sysdate, DATETIME, true, LocalDateTime::now); + } + + @Test + void curtime() { + test_now_like_functions(DSL::curtime, TIME, false, + () -> LocalTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void currdate() { + + test_now_like_functions(DSL::curdate, + DATE, false, + () -> LocalDate.now(functionProperties.getQueryStartClock())); + } + + @Test + void current_time() { + test_now_like_functions(DSL::current_time, + TIME, + false, + () -> LocalTime.now(functionProperties.getQueryStartClock())); + } + + @Test + void current_date() { + test_now_like_functions(DSL::current_date, DATE, false, + () -> LocalDate.now(functionProperties.getQueryStartClock())); } /** * Check how NOW-like functions are processed. - * @param function Function - * @param name Function name - * @param resType Return type - * @param hasFsp Whether function has fsp argument + * + * @param function Function + * @param resType Return type + * @param hasFsp Whether function has fsp argument * @param referenceGetter A callback to get reference value */ - @ParameterizedTest(name = "{1}") - @MethodSource("functionNames") - public void test_now_like_functions(Function function, - @SuppressWarnings("unused") // Used in the test name above - String name, - ExprCoreType resType, - Boolean hasFsp, - Supplier referenceGetter) { + void test_now_like_functions( + BiFunction function, + ExprCoreType resType, + Boolean hasFsp, + Supplier referenceGetter) { // Check return types: // `func()` - FunctionExpression expr = function.apply(new Expression[]{}); + FunctionExpression expr = function.apply(functionProperties, new Expression[] {}); assertEquals(resType, expr.type()); if (hasFsp) { // `func(fsp = 0)` - expr = function.apply(new Expression[]{DSL.literal(0)}); + expr = function.apply(functionProperties, new Expression[] {DSL.literal(0)}); assertEquals(resType, expr.type()); // `func(fsp = 6)` - expr = function.apply(new Expression[]{DSL.literal(6)}); + expr = function.apply(functionProperties, new Expression[] {DSL.literal(6)}); assertEquals(resType, expr.type()); - for (var wrongFspValue: List.of(-1, 10)) { + for (var wrongFspValue : List.of(-1, 10)) { var exception = assertThrows(IllegalArgumentException.class, - () -> function.apply(new Expression[]{DSL.literal(wrongFspValue)}).valueOf()); + () -> function.apply(functionProperties, + new Expression[] {DSL.literal(wrongFspValue)}).valueOf()); assertEquals(String.format("Invalid `fsp` value: %d, allowed 0 to 6", wrongFspValue), exception.getMessage()); } @@ -111,16 +139,89 @@ public void test_now_like_functions(Function f // Check how calculations are precise: // `func()` - assertTrue(Math.abs(getDiff( - extractValue(function.apply(new Expression[]{})), - referenceGetter.get() - )) <= 1); + Temporal sample = extractValue(function.apply(functionProperties, new Expression[] {})); + Temporal reference = referenceGetter.get(); + long maxDiff = 1; + TemporalUnit unit = resType.isCompatible(DATE) ? ChronoUnit.DAYS : ChronoUnit.SECONDS; + assertThat(sample, isCloseTo(reference, maxDiff, unit)); if (hasFsp) { // `func(fsp)` - assertTrue(Math.abs(getDiff( - extractValue(function.apply(new Expression[]{DSL.literal(0)})), - referenceGetter.get() - )) <= 1); + Temporal value = extractValue(function.apply(functionProperties, + new Expression[] {DSL.literal(0)})); + assertThat(referenceGetter.get(), + isCloseTo(value, maxDiff, unit)); + + } + } + + static Matcher isCloseTo(Temporal reference, long maxDiff, TemporalUnit units) { + return new BaseMatcher<>() { + @Override + public void describeTo(Description description) { + description.appendText("value between ") + .appendValue(reference.minus(maxDiff, units)) + .appendText(" and ") + .appendValue(reference.plus(maxDiff, units)); + } + + @Override + public boolean matches(Object value) { + if (value instanceof Temporal) { + Temporal temporalValue = (Temporal) value; + long diff = reference.until(temporalValue, units); + return Math.abs(diff) <= maxDiff; + } + return false; + } + + + }; + } + + @TestFactory + Stream constantValueTestFactory() { + BiFunction, DynamicTest> buildTest + = (name, action) -> + DynamicTest.dynamicTest( + String.format("multiple_invocations_same_value_test[%s]", name), + () -> { + var v1 = extractValue(action.apply(functionProperties)); + Thread.sleep(1000); + var v2 = extractValue(action.apply(functionProperties)); + assertEquals(v1, v2); + } + ); + return Stream.of( + buildTest.apply("now", DSL::now), + buildTest.apply("current_timestamp", DSL::current_timestamp), + buildTest.apply("current_time", DSL::current_time), + buildTest.apply("curdate", DSL::curdate), + buildTest.apply("curtime", DSL::curtime), + buildTest.apply("localtimestamp", DSL::localtimestamp), + buildTest.apply("localtime", DSL::localtime) + ); + } + + @Test + void sysdate_multiple_invocations_differ() throws InterruptedException { + var v1 = extractValue(DSL.sysdate(functionProperties)); + Thread.sleep(1000); + var v2 = extractValue(DSL.sysdate(functionProperties)); + assertThat(v1, IsNot.not(isCloseTo(v2, 1, ChronoUnit.NANOS))); + + } + + private Temporal extractValue(FunctionExpression func) { + switch ((ExprCoreType) func.type()) { + case DATE: + return func.valueOf().dateValue(); + case DATETIME: + return func.valueOf().datetimeValue(); + case TIME: + return func.valueOf().timeValue(); + // unreachable code + default: + throw new IllegalArgumentException(String.format("%s", func.type())); } } } diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java index 4e7541177f..f6e24f4e27 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTimeStampTest.java @@ -30,8 +30,10 @@ public class UnixTimeStampTest extends DateTimeTestBase { @Test public void checkNoArgs() { - assertEquals(System.currentTimeMillis() / 1000L, unixTimeStamp()); - assertEquals(System.currentTimeMillis() / 1000L, eval(unixTimeStampExpr()).longValue()); + + final long expected = functionProperties.getQueryStartClock().millis() / 1000L; + assertEquals(expected, unixTimeStamp()); + assertEquals(expected, eval(unixTimeStampExpr()).longValue()); } private static Stream getDateSamples() { diff --git a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java index a5ddee2d0c..cb4d318eb3 100644 --- a/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/datetime/UnixTwoWayConversionTest.java @@ -24,9 +24,18 @@ public class UnixTwoWayConversionTest extends DateTimeTestBase { @Test public void checkConvertNow() { - assertEquals(LocalDateTime.now(ZoneId.of("UTC")).withNano(0), fromUnixTime(unixTimeStamp())); - assertEquals(LocalDateTime.now(ZoneId.of("UTC")).withNano(0), - eval(fromUnixTime(unixTimeStampExpr())).datetimeValue()); + assertEquals(getExpectedNow(), fromUnixTime(unixTimeStamp())); + } + + @Test + public void checkConvertNow_with_eval() { + assertEquals(getExpectedNow(), eval(fromUnixTime(unixTimeStampExpr())).datetimeValue()); + } + + private LocalDateTime getExpectedNow() { + return LocalDateTime.now( + functionProperties.getQueryStartClock().withZone(ZoneId.of("UTC"))) + .withNano(0); } private static Stream getDoubleSamples() { diff --git a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java index f63304e6b5..8bba3bd9b9 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/BuiltinFunctionRepositoryTest.java @@ -58,6 +58,8 @@ class BuiltinFunctionRepositoryTest { @Mock private Map mockMap; @Mock + FunctionProperties functionProperties; + @Mock private FunctionName mockFunctionName; @Mock private FunctionBuilder functionExpressionBuilder; @@ -114,8 +116,9 @@ void compile() { BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); repo.register(mockfunctionResolver); - repo.compile(mockFunctionName, Arrays.asList(mockExpression)); - verify(functionExpressionBuilder, times(1)).apply(any()); + repo.compile(functionProperties, mockFunctionName, Arrays.asList(mockExpression)); + verify(functionExpressionBuilder, times(1)) + .apply(eq(functionProperties), any()); } @@ -133,8 +136,10 @@ void compile_function_under_datasource_namespace() { BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); repo.register(TEST_NAMESPACE, mockfunctionResolver); - repo.compile(TEST_NAMESPACE, mockFunctionName, Arrays.asList(mockExpression)); - verify(functionExpressionBuilder, times(1)).apply(any()); + repo.compile(functionProperties, TEST_NAMESPACE, mockFunctionName, + Arrays.asList(mockExpression)); + verify(functionExpressionBuilder, times(1)) + .apply(eq(functionProperties), any()); } @Test @@ -151,17 +156,18 @@ void resolve() { BuiltinFunctionRepository repo = new BuiltinFunctionRepository(mockNamespaceMap); repo.register(mockfunctionResolver); - assertEquals(functionExpressionBuilder, - repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), functionSignature)); + assertEquals(functionExpressionBuilder, repo.resolve( + Collections.singletonList(DEFAULT_NAMESPACE), functionSignature)); } @Test void resolve_should_not_cast_arguments_in_cast_function() { when(mockExpression.toString()).thenReturn("string"); FunctionImplementation function = - repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), + repo.resolve( + Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(CAST_TO_BOOLEAN.getName(), DATETIME, BOOLEAN)) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("cast_to_boolean(string)", function.toString()); } @@ -170,9 +176,10 @@ void resolve_should_not_cast_arguments_if_same_type() { when(mockFunctionName.getFunctionName()).thenReturn("mock"); when(mockExpression.toString()).thenReturn("string"); FunctionImplementation function = - repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), + repo.resolve( + Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(mockFunctionName, STRING, STRING)) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("mock(string)", function.toString()); } @@ -181,9 +188,10 @@ void resolve_should_not_cast_arguments_if_both_numbers() { when(mockFunctionName.getFunctionName()).thenReturn("mock"); when(mockExpression.toString()).thenReturn("byte"); FunctionImplementation function = - repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), + repo.resolve( + Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(mockFunctionName, BYTE, INTEGER)) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("mock(byte)", function.toString()); } @@ -199,7 +207,7 @@ void resolve_should_cast_arguments() { FunctionImplementation function = repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), signature) - .apply(ImmutableList.of(mockExpression)); + .apply(functionProperties, ImmutableList.of(mockExpression)); assertEquals("mock(cast_to_boolean(string))", function.toString()); } @@ -207,9 +215,10 @@ void resolve_should_cast_arguments() { void resolve_should_throw_exception_for_unsupported_conversion() { ExpressionEvaluationException error = assertThrows(ExpressionEvaluationException.class, () -> - repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), + repo.resolve( + Collections.singletonList(DEFAULT_NAMESPACE), registerFunctionResolver(mockFunctionName, BYTE, STRUCT)) - .apply(ImmutableList.of(mockExpression))); + .apply(functionProperties, ImmutableList.of(mockExpression))); assertEquals(error.getMessage(), "Type conversion to type STRUCT is not supported"); } @@ -223,8 +232,9 @@ void resolve_unregistered() { repo.register(mockfunctionResolver); ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, - () -> repo.resolve(Collections.singletonList(DEFAULT_NAMESPACE), - new FunctionSignature(FunctionName.of("unknown"), Arrays.asList()))); + () -> repo.resolve( + Collections.singletonList(DEFAULT_NAMESPACE), + new FunctionSignature(FunctionName.of("unknown"), List.of()))); assertEquals("unsupported function name: unknown", exception.getMessage()); } @@ -249,8 +259,8 @@ private FunctionSignature registerFunctionResolver(FunctionName funcName, // Relax unnecessary stubbing check because error case test doesn't call this lenient().doAnswer(invocation -> - new FakeFunctionExpression(funcName, invocation.getArgument(0)) - ).when(funcBuilder).apply(any()); + new FakeFunctionExpression(funcName, invocation.getArgument(1)) + ).when(funcBuilder).apply(eq(functionProperties), any()); return unresolvedSignature; } diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java new file mode 100644 index 0000000000..8bf4d7ba24 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLDefineTest.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.expression.function.FunctionDSL.define; + +import java.util.List; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.expression.Expression; + +@ExtendWith(MockitoExtension.class) +class FunctionDSLDefineTest extends FunctionDSLTestBase { + + @Test + void define_variableArgs_test() { + Pair resolvedA = + Pair.of(SAMPLE_SIGNATURE_A, new SampleFunctionBuilder()); + + FunctionResolver resolver = define(SAMPLE_NAME, v -> resolvedA); + + assertEquals(resolvedA, resolver.resolve(SAMPLE_SIGNATURE_A)); + } + + @Test + void define_test() { + Pair resolved = + Pair.of(SAMPLE_SIGNATURE_A, new SampleFunctionBuilder()); + + FunctionResolver resolver = define(SAMPLE_NAME, List.of(v -> resolved)); + + assertEquals(resolved, resolver.resolve(SAMPLE_SIGNATURE_A)); + } + + @Test + void define_name_test() { + Pair resolved = + Pair.of(SAMPLE_SIGNATURE_A, new SampleFunctionBuilder()); + + FunctionResolver resolver = define(SAMPLE_NAME, List.of(v -> resolved)); + + assertEquals(SAMPLE_NAME, resolver.getFunctionName()); + } + + static class SampleFunctionBuilder implements FunctionBuilder { + @Override + public FunctionImplementation apply(FunctionProperties functionProperties, + List arguments) { + return new SampleFunctionImplementation(arguments); + } + } + + @RequiredArgsConstructor + static class SampleFunctionImplementation implements FunctionImplementation { + @Getter + private final List arguments; + + @Override + public FunctionName getFunctionName() { + return SAMPLE_NAME; + } + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java new file mode 100644 index 0000000000..193066e626 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLTestBase.java @@ -0,0 +1,62 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.List; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprMissingValue; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.data.type.ExprType; + +@ExtendWith(MockitoExtension.class) +public class FunctionDSLTestBase { + @Mock + FunctionProperties functionProperties; + + public static final ExprNullValue NULL = ExprNullValue.of(); + public static final ExprMissingValue MISSING = ExprMissingValue.of(); + protected static final ExprType ANY_TYPE = () -> "ANY"; + protected static final ExprValue ANY = new ExprValue() { + @Override + public Object value() { + throw new RuntimeException(); + } + + @Override + public ExprType type() { + return ANY_TYPE; + } + + @Override + public String toString() { + return "ANY"; + } + + @Override + public int compareTo(ExprValue o) { + throw new RuntimeException(); + } + }; + static final FunctionName SAMPLE_NAME = FunctionName.of("sample"); + static final FunctionSignature SAMPLE_SIGNATURE_A = + new FunctionSignature(SAMPLE_NAME, List.of(ExprCoreType.UNDEFINED)); + static final SerializableNoArgFunction noArg = () -> ANY; + static final SerializableFunction oneArg = v -> ANY; + static final SerializableBiFunction + oneArgWithProperties = (functionProperties, v) -> ANY; + + static final SerializableBiFunction + twoArgs = (v1, v2) -> ANY; + static final SerializableTriFunction + threeArgs = (v1, v2, v3) -> ANY; + @Mock + FunctionProperties mockProperties; +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java new file mode 100644 index 0000000000..5d970803ed --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplNoArgTest.java @@ -0,0 +1,31 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplNoArgTest extends FunctionDSLimplTestBase { + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(noArg, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(); + } + + @Override + String getExpected_toString() { + return "sample()"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java new file mode 100644 index 0000000000..6e7c194f5d --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplOneArgTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplOneArgTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(oneArg, ANY_TYPE, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java new file mode 100644 index 0000000000..8f494c01c3 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTestBase.java @@ -0,0 +1,89 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.Test; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; + +abstract class FunctionDSLimplTestBase extends FunctionDSLTestBase { + @Test + void implementationGenerator_is_valid() { + assertNotNull(getImplementationGenerator()); + } + + @Test + void implementation_is_valid_pair() { + assertNotNull(getImplementation().getKey()); + assertNotNull(getImplementation().getValue()); + } + + @Test + void implementation_expected_functionName() { + assertEquals(SAMPLE_NAME, getImplementation().getKey().getFunctionName()); + } + + @Test + void implementation_valid_functionBuilder() { + + FunctionBuilder v = getImplementation().getValue(); + assertDoesNotThrow(() -> v.apply(functionProperties, getSampleArguments())); + } + + @Test + void implementation_functionBuilder_return_functionExpression() { + FunctionImplementation executable = getImplementation().getValue() + .apply(functionProperties, getSampleArguments()); + assertTrue(executable instanceof FunctionExpression); + } + + @Test + void implementation_functionExpression_valueOf() { + FunctionExpression executable = + (FunctionExpression) getImplementation().getValue() + .apply(functionProperties, getSampleArguments()); + + assertEquals(ANY, executable.valueOf(null)); + } + + @Test + void implementation_functionExpression_type() { + FunctionExpression executable = + (FunctionExpression) getImplementation().getValue() + .apply(functionProperties, getSampleArguments()); + assertEquals(ANY_TYPE, executable.type()); + } + + @Test + void implementation_functionExpression_toString() { + FunctionExpression executable = + (FunctionExpression) getImplementation().getValue() + .apply(functionProperties, getSampleArguments()); + assertEquals(getExpected_toString(), executable.toString()); + } + + /** + * A lambda that takes a function name and returns an implementation + * of the function. + */ + abstract SerializableFunction> + getImplementationGenerator(); + + Pair getImplementation() { + return getImplementationGenerator().apply(SAMPLE_NAME); + } + + abstract List getSampleArguments(); + + abstract String getExpected_toString(); +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java new file mode 100644 index 0000000000..eab8d24c59 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplThreeArgTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplThreeArgTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(threeArgs, ANY_TYPE, ANY_TYPE, ANY_TYPE, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY), DSL.literal(ANY), DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY, ANY, ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java new file mode 100644 index 0000000000..87d097c9eb --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplTwoArgTest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.opensearch.sql.expression.function.FunctionDSL.impl; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplTwoArgTest extends FunctionDSLimplTestBase { + @Override + SerializableFunction> + getImplementationGenerator() { + return impl(twoArgs, ANY_TYPE, ANY_TYPE, ANY_TYPE); + } + + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY), DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY, ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java new file mode 100644 index 0000000000..c3c41b6c0c --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesNoArgsTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplWithPropertiesNoArgsTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + return FunctionDSL.implWithProperties(fp -> ANY, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(); + } + + @Override + String getExpected_toString() { + return "sample()"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java new file mode 100644 index 0000000000..4a05326c0a --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLimplWithPropertiesOneArgTest.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + + +import java.util.List; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; + +class FunctionDSLimplWithPropertiesOneArgTest extends FunctionDSLimplTestBase { + + @Override + SerializableFunction> + getImplementationGenerator() { + SerializableBiFunction functionBody + = (fp, arg) -> ANY; + return FunctionDSL.implWithProperties(functionBody, ANY_TYPE, ANY_TYPE); + } + + @Override + List getSampleArguments() { + return List.of(DSL.literal(ANY)); + } + + @Override + String getExpected_toString() { + return "sample(ANY)"; + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java new file mode 100644 index 0000000000..64cac278f6 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionDSLnullMissingHandlingTest.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandling; +import static org.opensearch.sql.expression.function.FunctionDSL.nullMissingHandlingWithProperties; + +import org.junit.jupiter.api.Test; + +class FunctionDSLnullMissingHandlingTest extends FunctionDSLTestBase { + + @Test + void nullMissingHandling_oneArg_nullValue() { + assertEquals(NULL, nullMissingHandling(oneArg).apply(NULL)); + } + + @Test + void nullMissingHandling_oneArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(oneArg).apply(MISSING)); + } + + @Test + void nullMissingHandling_oneArg_apply() { + assertEquals(ANY, nullMissingHandling(oneArg).apply(ANY)); + } + + + @Test + void nullMissingHandling_oneArg_FunctionProperties_nullValue() { + assertEquals(NULL, + nullMissingHandlingWithProperties(oneArgWithProperties).apply(functionProperties, NULL)); + } + + @Test + void nullMissingHandling_oneArg_FunctionProperties_missingValue() { + assertEquals(MISSING, + nullMissingHandlingWithProperties(oneArgWithProperties).apply(functionProperties, MISSING)); + } + + @Test + void nullMissingHandling_oneArg_FunctionProperties_apply() { + assertEquals(ANY, + nullMissingHandlingWithProperties(oneArgWithProperties).apply(functionProperties, ANY)); + } + + @Test + void nullMissingHandling_twoArgs_firstArg_nullValue() { + assertEquals(NULL, nullMissingHandling(twoArgs).apply(NULL, ANY)); + } + + @Test + void nullMissingHandling_twoArgs_secondArg_nullValue() { + assertEquals(NULL, nullMissingHandling(twoArgs).apply(ANY, NULL)); + } + + + @Test + void nullMissingHandling_twoArgs_firstArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(twoArgs).apply(MISSING, ANY)); + } + + @Test + void nullMissingHandling_twoArgs_secondArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(twoArgs).apply(ANY, MISSING)); + } + + @Test + void nullMissingHandling_twoArgs_apply() { + assertEquals(ANY, nullMissingHandling(twoArgs).apply(ANY, ANY)); + } + + + @Test + void nullMissingHandling_threeArgs_firstArg_nullValue() { + assertEquals(NULL, nullMissingHandling(threeArgs).apply(NULL, ANY, ANY)); + } + + @Test + void nullMissingHandling_threeArgs_secondArg_nullValue() { + assertEquals(NULL, nullMissingHandling(threeArgs).apply(ANY, NULL, ANY)); + } + + @Test + void nullMissingHandling_threeArgs_thirdArg_nullValue() { + assertEquals(NULL, nullMissingHandling(threeArgs).apply(ANY, ANY, NULL)); + } + + + @Test + void nullMissingHandling_threeArgs_firstArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(threeArgs).apply(MISSING, ANY, ANY)); + } + + @Test + void nullMissingHandling_threeArg_secondArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(threeArgs).apply(ANY, MISSING, ANY)); + } + + @Test + void nullMissingHandling_threeArg_thirdArg_missingValue() { + assertEquals(MISSING, nullMissingHandling(threeArgs).apply(ANY, ANY, MISSING)); + } + + @Test + void nullMissingHandling_threeArg_apply() { + assertEquals(ANY, nullMissingHandling(threeArgs).apply(ANY, ANY, ANY)); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java new file mode 100644 index 0000000000..64ec21e7e1 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTest.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.time.Clock; +import java.time.Instant; +import java.time.ZoneId; +import java.util.concurrent.Callable; +import java.util.function.Consumer; +import java.util.function.Function; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.DynamicTest; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestFactory; +import org.junit.jupiter.api.function.Executable; + +class FunctionPropertiesTest { + + FunctionProperties functionProperties; + Instant startTime; + + @BeforeEach + void init() { + startTime = Instant.now(); + functionProperties = new FunctionProperties(startTime, ZoneId.systemDefault()); + } + + @Test + void getQueryStartClock_returns_constructor_instant() { + assertEquals(startTime, functionProperties.getQueryStartClock().instant()); + } + + @Test + void getQueryStartClock_differs_from_instantNow() throws InterruptedException { + // Give system clock a chance to advance. + Thread.sleep(1000); + assertNotEquals(Instant.now(), functionProperties.getQueryStartClock().instant()); + } + + @Test + void getSystemClock_is_systemClock() { + assertEquals(Clock.systemDefaultZone(), functionProperties.getSystemClock()); + } + + @Test + void functionProperties_can_be_serialized() throws IOException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(functionProperties); + objectOutput.flush(); + assertNotEquals(0, output.size()); + } + + @Test + void functionProperties_can_be_deserialized() throws IOException, ClassNotFoundException { + ByteArrayOutputStream output = new ByteArrayOutputStream(); + ObjectOutputStream objectOutput = new ObjectOutputStream(output); + objectOutput.writeObject(functionProperties); + objectOutput.flush(); + ByteArrayInputStream input = new ByteArrayInputStream(output.toByteArray()); + ObjectInputStream objectInput = new ObjectInputStream(input); + assertEquals(functionProperties, objectInput.readObject()); + } + + @TestFactory + Stream functionProperties_none_throws_on_access() { + Consumer tb = tc -> { + RuntimeException e = assertThrows(FunctionProperties.UnexpectedCallException.class, tc); + assertEquals("FunctionProperties.None is a null object and not meant to be accessed.", + e.getMessage()); + }; + return Stream.of( + DynamicTest.dynamicTest("getQueryStartClock", + () -> tb.accept(FunctionProperties.None::getQueryStartClock)), + DynamicTest.dynamicTest("getSystemClock", + () -> tb.accept(FunctionProperties.None::getSystemClock)) + ); + } +} diff --git a/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTestConfig.java b/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTestConfig.java new file mode 100644 index 0000000000..dfc9b543ae --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/function/FunctionPropertiesTestConfig.java @@ -0,0 +1,17 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.function; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +@Configuration +public class FunctionPropertiesTestConfig { + @Bean + FunctionProperties functionProperties() { + return new FunctionProperties(); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java index f0fe041313..33050e7200 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StandaloneIT.java @@ -32,6 +32,7 @@ import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; import org.opensearch.sql.expression.function.BuiltinFunctionRepository; +import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.monitor.AlwaysHealthyMonitor; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.client.OpenSearchRestClient; @@ -75,6 +76,7 @@ public void init() { new OpenSearchExecutionProtector(new AlwaysHealthyMonitor()))); context.registerBean(OpenSearchClient.class, () -> client); context.registerBean(Settings.class, () -> defaultSettings()); + context.registerBean(FunctionProperties.class, FunctionProperties::new); DataSourceService dataSourceService = new DataSourceServiceImpl( new ImmutableSet.Builder() .add(new OpenSearchDataSourceFactory(client, defaultSettings())) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java index 602571434b..b62d545206 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/AggregationQueryBuilderTest.java @@ -57,7 +57,6 @@ @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @ExtendWith(MockitoExtension.class) class AggregationQueryBuilderTest { - @Mock private ExpressionSerializer serializer; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java index c6ec2d95c1..a0b9e5f318 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/MatchPhrasePrefixQueryTest.java @@ -25,7 +25,7 @@ import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhrasePrefixQuery; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) -public class MatchPhrasePrefixQueryTest { +public class MatchPhrasePrefixQueryTest { private final MatchPhrasePrefixQuery matchPhrasePrefixQuery = new MatchPhrasePrefixQuery(); private final FunctionName matchPhrasePrefix = FunctionName.of("match_phrase_prefix"); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java index ab3dc406b8..32c02959b8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/QueryStringTest.java @@ -45,7 +45,7 @@ class QueryStringTest { static Stream> generateValidData() { Expression field = DSL.namedArgument("fields", fields_value); Expression query = DSL.namedArgument("query", query_value); - return List.of( + return Stream.of( DSL.namedArgument("analyzer", DSL.literal("standard")), DSL.namedArgument("analyze_wildcard", DSL.literal("true")), DSL.namedArgument("allow_leading_wildcard", DSL.literal("true")), @@ -75,32 +75,32 @@ static Stream> generateValidData() { DSL.namedArgument("Allow_Leading_wildcard", DSL.literal("true")), DSL.namedArgument("Auto_Generate_Synonyms_Phrase_Query", DSL.literal("true")), DSL.namedArgument("Boost", DSL.literal("1")) - ).stream().map(arg -> List.of(field, query, arg)); + ).map(arg -> List.of(field, query, arg)); } @ParameterizedTest @MethodSource("generateValidData") - public void test_valid_parameters(List validArgs) { + void test_valid_parameters(List validArgs) { Assertions.assertNotNull(queryStringQuery.build( new QueryStringExpression(validArgs))); } @Test - public void test_SyntaxCheckException_when_no_arguments() { + void test_SyntaxCheckException_when_no_arguments() { List arguments = List.of(); assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } @Test - public void test_SyntaxCheckException_when_one_argument() { + void test_SyntaxCheckException_when_one_argument() { List arguments = List.of(namedArgument("fields", fields_value)); assertThrows(SyntaxCheckException.class, () -> queryStringQuery.build(new QueryStringExpression(arguments))); } @Test - public void test_SemanticCheckException_when_invalid_parameter() { + void test_SemanticCheckException_when_invalid_parameter() { List arguments = List.of( namedArgument("fields", fields_value), namedArgument("query", query_value), diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java index 678a77b82f..01ec85d64d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/MultiFieldQueryTest.java @@ -25,9 +25,9 @@ import org.opensearch.sql.expression.LiteralExpression; class MultiFieldQueryTest { - MultiFieldQuery query; + MultiFieldQuery query; private final String testQueryName = "test_query"; - private final Map actionMap + private final Map> actionMap = ImmutableMap.of("paramA", (o, v) -> o); @BeforeEach diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java index f3e3bf5dfc..df6cfae78f 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/sort/SortQueryBuilderTest.java @@ -18,7 +18,7 @@ class SortQueryBuilderTest { - private SortQueryBuilder sortQueryBuilder = new SortQueryBuilder(); + private final SortQueryBuilder sortQueryBuilder = new SortQueryBuilder(); @Test void build_sortbuilder_from_reference() { diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 9a414d9bac..76d8e38eff 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -265,11 +265,6 @@ primaryExpression | dataTypeFunctionCall | fieldExpression | literalValue - | constantFunction - ; - -constantFunction - : constantFunctionName LT_PRTHS functionArgs? RT_PRTHS ; booleanExpression @@ -427,17 +422,51 @@ trigonometricFunctionName ; dateAndTimeFunctionBase - : ADDDATE | CONVERT_TZ | DATE | DATE_ADD | DATE_FORMAT | DATE_SUB - | DATETIME | DAY | DAYNAME | DAYOFMONTH | DAYOFWEEK | DAYOFYEAR | FROM_DAYS | FROM_UNIXTIME - | HOUR | MAKEDATE | MAKETIME | MICROSECOND | MINUTE | MONTH | MONTHNAME | PERIOD_ADD - | PERIOD_DIFF | QUARTER | SECOND | SUBDATE | SYSDATE | TIME | TIME_TO_SEC - | TIMESTAMP | TO_DAYS | UNIX_TIMESTAMP | WEEK | YEAR - ; - -// Functions which value could be cached in scope of a single query -constantFunctionName - : CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | LOCALTIME | LOCALTIMESTAMP | UTC_TIMESTAMP | UTC_DATE | UTC_TIME - | CURDATE | CURTIME | NOW + : ADDDATE + | CONVERT_TZ + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | DATE + | DATE_ADD + | DATE_FORMAT + | DATE_SUB + | DATETIME + | DAY + | DAYNAME + | DAYOFMONTH + | DAYOFWEEK + | DAYOFYEAR + | CURDATE + | CURTIME + | FROM_DAYS + | FROM_UNIXTIME + | HOUR + | LOCALTIME + | LOCALTIMESTAMP + | MAKEDATE + | MAKETIME + | MICROSECOND + | MINUTE + | MONTH + | MONTHNAME + | NOW + | PERIOD_ADD + | PERIOD_DIFF + | QUARTER + | SECOND + | SUBDATE + | SYSDATE + | TIME + | TIME_TO_SEC + | TIMESTAMP + | TO_DAYS + | UNIX_TIMESTAMP + | UTC_DATE + | UTC_TIME + | UTC_TIMESTAMP + | WEEK + | YEAR ; /** condition function return boolean value */ @@ -571,7 +600,6 @@ keywordsCanBeId | TIMESTAMP | DATE | TIME | FIRST | LAST | timespanUnit | SPAN - | constantFunctionName | dateAndTimeFunctionBase | textFunctionBase | mathematicalFunctionBase diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 4430820081..115bcf3cd8 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -14,7 +14,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BySpanClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CompareExprContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConstantFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ConvertedDataTypeContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.CountAllFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.DataTypeFunctionCallContext; @@ -23,7 +22,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalClauseContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.EvalFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FieldExpressionContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.FunctionArgsContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsQualifiedNameContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.IdentsAsWildcardQualifiedNameContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.InExprContext; @@ -61,7 +59,6 @@ import org.opensearch.sql.ast.expression.Argument; import org.opensearch.sql.ast.expression.Cast; import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Field; import org.opensearch.sql.ast.expression.Function; @@ -232,8 +229,8 @@ public UnresolvedExpression visitTakeAggFunctionCall( @Override public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext ctx) { final String functionName = ctx.conditionFunctionBase().getText(); - return visitFunction(FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), - ctx.functionArgs()); + return buildFunction(FUNCTION_NAME_MAPPING.getOrDefault(functionName, functionName), + ctx.functionArgs().functionArg()); } /** @@ -241,7 +238,7 @@ public UnresolvedExpression visitBooleanFunctionCall(BooleanFunctionCallContext */ @Override public UnresolvedExpression visitEvalFunctionCall(EvalFunctionCallContext ctx) { - return visitFunction(ctx.evalFunctionName().getText(), ctx.functionArgs()); + return buildFunction(ctx.evalFunctionName().getText(), ctx.functionArgs().functionArg()); } /** @@ -257,23 +254,11 @@ public UnresolvedExpression visitConvertedDataType(ConvertedDataTypeContext ctx) return AstDSL.stringLiteral(ctx.getText()); } - public UnresolvedExpression visitConstantFunction(ConstantFunctionContext ctx) { - return visitConstantFunction(ctx.constantFunctionName().getText(), - ctx.functionArgs()); - } - - private UnresolvedExpression visitConstantFunction(String functionName, - FunctionArgsContext args) { - return new ConstantFunction(functionName, args.functionArg() - .stream() - .map(this::visitFunctionArg) - .collect(Collectors.toList())); - } - - private Function visitFunction(String functionName, FunctionArgsContext args) { + private Function buildFunction(String functionName, + List args) { return new Function( functionName, - args.functionArg() + args .stream() .map(this::visitFunctionArg) .collect(Collectors.toList()) @@ -372,7 +357,7 @@ private QualifiedName visitIdentifiers(List ctx) { } private List singleFieldRelevanceArguments( - OpenSearchPPLParser.SingleFieldRelevanceFunctionContext ctx) { + SingleFieldRelevanceFunctionContext ctx) { // all the arguments are defaulted to string values // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); @@ -387,7 +372,7 @@ private List singleFieldRelevanceArguments( } private List multiFieldRelevanceArguments( - OpenSearchPPLParser.MultiFieldRelevanceFunctionContext ctx) { + MultiFieldRelevanceFunctionContext ctx) { // all the arguments are defaulted to string values // to skip environment resolving and function signature resolving ImmutableList.Builder builder = ImmutableList.builder(); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java index d301b7b918..a1b1ccaf14 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/PPLServiceTest.java @@ -28,6 +28,7 @@ import org.opensearch.sql.executor.QueryManager; import org.opensearch.sql.executor.QueryService; import org.opensearch.sql.executor.execution.QueryPlanFactory; +import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.ppl.config.PPLServiceConfig; import org.opensearch.sql.ppl.domain.PPLQueryRequest; import org.opensearch.sql.storage.StorageEngine; @@ -72,6 +73,7 @@ public void setUp() { context.registerBean(StorageEngine.class, () -> storageEngine); context.registerBean(ExecutionEngine.class, () -> executionEngine); context.registerBean(DataSourceService.class, () -> dataSourceService); + context.registerBean(FunctionProperties.class, FunctionProperties::new); context.register(PPLServiceConfig.class); context.refresh(); pplService = context.getBean(PPLService.class); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java index 1350305391..711e780f3b 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstNowLikeFunctionTest.java @@ -8,7 +8,6 @@ import static org.junit.Assert.assertEquals; import static org.opensearch.sql.ast.dsl.AstDSL.compare; -import static org.opensearch.sql.ast.dsl.AstDSL.constantFunction; import static org.opensearch.sql.ast.dsl.AstDSL.eval; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; @@ -18,6 +17,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.relation; import java.util.List; +import org.junit.Assume; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -34,14 +34,11 @@ public class AstNowLikeFunctionTest { * @param name Function name * @param hasFsp Whether function has fsp argument * @param hasShortcut Whether function has shortcut (call without `()`) - * @param isConstantFunction Whether function has constant value */ - public AstNowLikeFunctionTest(String name, Boolean hasFsp, Boolean hasShortcut, - Boolean isConstantFunction) { + public AstNowLikeFunctionTest(String name, Boolean hasFsp, Boolean hasShortcut) { this.name = name; this.hasFsp = hasFsp; this.hasShortcut = hasShortcut; - this.isConstantFunction = isConstantFunction; } /** @@ -51,55 +48,67 @@ public AstNowLikeFunctionTest(String name, Boolean hasFsp, Boolean hasShortcut, @Parameterized.Parameters(name = "{0}") public static Iterable functionNames() { return List.of(new Object[][]{ - {"now", false, false, true}, - {"current_timestamp", false, false, true}, - {"localtimestamp", false, false, true}, - {"localtime", false, false, true}, - {"sysdate", true, false, false}, - {"curtime", false, false, true}, - {"current_time", false, false, true}, - {"curdate", false, false, true}, - {"current_date", false, false, true} + {"now", false, false }, + {"current_timestamp", false, false}, + {"localtimestamp", false, false}, + {"localtime", false, false}, + {"sysdate", true, false}, + {"curtime", false, false}, + {"current_time", false, false}, + {"curdate", false, false}, + {"current_date", false, false} }); } private final String name; - private final Boolean hasFsp; - private final Boolean hasShortcut; - private final Boolean isConstantFunction; + private final boolean hasFsp; + private final boolean hasShortcut; @Test - public void test_now_like_functions() { - for (var call : hasShortcut ? List.of(name, name + "()") : List.of(name + "()")) { - assertEqual("source=t | eval r=" + call, - eval( - relation("t"), - let( - field("r"), - (isConstantFunction ? constantFunction(name) : function(name)) - ) - )); - - assertEqual("search source=t | where a=" + call, - filter( - relation("t"), - compare("=", field("a"), - (isConstantFunction ? constantFunction(name) : function(name))) - ) - ); - } - // Unfortunately, only real functions (not ConstantFunction) might have `fsp` now. - if (hasFsp) { - assertEqual("search source=t | where a=" + name + "(0)", - filter( - relation("t"), - compare("=", field("a"), function(name, intLiteral(0))) - ) - ); - } + public void test_function_call_eval() { + assertEqual( + eval(relation("t"), let(field("r"), function(name))), + "source=t | eval r=" + name + "()" + ); } - protected void assertEqual(String query, Node expectedPlan) { + @Test + public void test_shortcut_eval() { + Assume.assumeTrue(hasShortcut); + assertEqual( + eval(relation("t"), let(field("r"), function(name))), + "source=t | eval r=" + name + ); + } + + @Test + public void test_function_call_where() { + assertEqual( + filter(relation("t"), compare("=", field("a"), function(name))), + "search source=t | where a=" + name + "()" + ); + } + + @Test + public void test_shortcut_where() { + Assume.assumeTrue(hasShortcut); + assertEqual( + filter(relation("t"), compare("=", field("a"), function(name))), + "search source=t | where a=" + name + ); + } + + @Test + public void test_function_call_fsp() { + Assume.assumeTrue(hasFsp); + assertEqual(filter( + relation("t"), + compare("=", field("a"), function(name, intLiteral(0))) + ), "search source=t | where a=" + name + "(0)" + ); + } + + protected void assertEqual(Node expectedPlan, String query) { Node actualPlan = plan(query); assertEquals(expectedPlan, actualPlan); } diff --git a/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java b/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java index 63d41fb1d8..a1d72f98c3 100644 --- a/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java +++ b/prometheus/src/main/java/org/opensearch/sql/prometheus/functions/resolver/QueryRangeTableFunctionResolver.java @@ -46,7 +46,7 @@ public Pair resolve(FunctionSignature unreso new FunctionSignature(functionName, List.of(STRING, LONG, LONG, LONG)); final List argumentNames = List.of(QUERY, STARTTIME, ENDTIME, STEP); - FunctionBuilder functionBuilder = arguments -> { + FunctionBuilder functionBuilder = (functionProperties, arguments) -> { Boolean argumentsPassedByName = arguments.stream() .noneMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName())); Boolean argumentsPassedByPosition = arguments.stream() diff --git a/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java b/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java index caca48f834..06f003c9b6 100644 --- a/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java +++ b/prometheus/src/test/java/org/opensearch/sql/prometheus/functions/QueryRangeTableFunctionResolverTest.java @@ -26,6 +26,7 @@ import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.function.FunctionBuilder; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.expression.function.FunctionProperties; import org.opensearch.sql.expression.function.FunctionSignature; import org.opensearch.sql.expression.function.TableFunctionImplementation; import org.opensearch.sql.prometheus.client.PrometheusClient; @@ -40,6 +41,9 @@ class QueryRangeTableFunctionResolverTest { @Mock private PrometheusClient client; + @Mock + private FunctionProperties functionProperties; + @Test void testResolve() { QueryRangeTableFunctionResolver queryRangeTableFunctionResolver @@ -59,7 +63,7 @@ void testResolve() { assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); FunctionBuilder functionBuilder = resolution.getValue(); TableFunctionImplementation functionImplementation - = (TableFunctionImplementation) functionBuilder.apply(expressions); + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); assertTrue(functionImplementation instanceof QueryRangeFunctionImplementation); PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable) functionImplementation.applyArguments(); @@ -93,7 +97,7 @@ void testArgumentsPassedByPosition() { assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); FunctionBuilder functionBuilder = resolution.getValue(); TableFunctionImplementation functionImplementation - = (TableFunctionImplementation) functionBuilder.apply(expressions); + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); assertTrue(functionImplementation instanceof QueryRangeFunctionImplementation); PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable) functionImplementation.applyArguments(); @@ -128,7 +132,7 @@ void testArgumentsPassedByNameWithDifferentOrder() { assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); FunctionBuilder functionBuilder = resolution.getValue(); TableFunctionImplementation functionImplementation - = (TableFunctionImplementation) functionBuilder.apply(expressions); + = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions); assertTrue(functionImplementation instanceof QueryRangeFunctionImplementation); PrometheusMetricTable prometheusMetricTable = (PrometheusMetricTable) functionImplementation.applyArguments(); @@ -160,7 +164,7 @@ void testMixedArgumentTypes() { assertEquals(functionName, queryRangeTableFunctionResolver.getFunctionName()); assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> resolution.getValue().apply(expressions)); + () -> resolution.getValue().apply(functionProperties, expressions)); assertEquals("Arguments should be either passed by name or position", exception.getMessage()); } @@ -182,7 +186,7 @@ void testWrongArgumentsSizeWhenPassedByName() { assertEquals(functionName, queryRangeTableFunctionResolver.getFunctionName()); assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> resolution.getValue().apply(expressions)); + () -> resolution.getValue().apply(functionProperties, expressions)); assertEquals("Missing arguments:[endtime,starttime]", exception.getMessage()); } @@ -204,7 +208,7 @@ void testWrongArgumentsSizeWhenPassedByPosition() { assertEquals(functionName, queryRangeTableFunctionResolver.getFunctionName()); assertEquals(List.of(STRING, LONG, LONG, LONG), resolution.getKey().getParamTypeList()); SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> resolution.getValue().apply(expressions)); + () -> resolution.getValue().apply(functionProperties, expressions)); assertEquals("Missing arguments:[endtime,step]", exception.getMessage()); } diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index e4c6f0a5e3..b17c25261a 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -239,6 +239,18 @@ timestampLiteral : TIMESTAMP timestamp=stringLiteral ; +// Actually, these constants are shortcuts to the corresponding functions +datetimeConstantLiteral + : CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | LOCALTIME + | LOCALTIMESTAMP + | UTC_TIMESTAMP + | UTC_DATE + | UTC_TIME + ; + intervalLiteral : INTERVAL expression intervalUnit ; @@ -301,12 +313,8 @@ functionCall | aggregateFunction (orderByClause)? filterClause #filteredAggregationFunctionCall | relevanceFunction #relevanceFunctionCall | highlightFunction #highlightFunctionCall - | constantFunction #constantFunctionCall ; -constantFunction - : constantFunctionName LR_BRACKET functionArgs RR_BRACKET - ; highlightFunction : HIGHLIGHT LR_BRACKET relevanceField (COMMA highlightArg)* RR_BRACKET @@ -393,17 +401,44 @@ trigonometricFunctionName ; dateTimeFunctionName - : ADDDATE | CONVERT_TZ | DATE | DATE_ADD | DATE_FORMAT | DATE_SUB - | DATETIME | DAY | DAYNAME | DAYOFMONTH | DAYOFWEEK | DAYOFYEAR | FROM_DAYS | FROM_UNIXTIME - | HOUR | MAKEDATE | MAKETIME | MICROSECOND | MINUTE | MONTH | MONTHNAME | PERIOD_ADD - | PERIOD_DIFF | QUARTER | SECOND | SUBDATE | SYSDATE | TIME | TIME_TO_SEC - | TIMESTAMP | TO_DAYS | UNIX_TIMESTAMP | WEEK | YEAR - ; - -// Functions which value could be cached in scope of a single query -constantFunctionName - : CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | LOCALTIME | LOCALTIMESTAMP | UTC_TIMESTAMP | UTC_DATE | UTC_TIME - | CURDATE | CURTIME | NOW + : datetimeConstantLiteral + | ADDDATE + | CONVERT_TZ + | CURDATE + | CURTIME + | DATE + | DATE_ADD + | DATE_FORMAT + | DATE_SUB + | DATETIME + | DAY + | DAYNAME + | DAYOFMONTH + | DAYOFWEEK + | DAYOFYEAR + | FROM_DAYS + | FROM_UNIXTIME + | HOUR + | MAKEDATE + | MAKETIME + | MICROSECOND + | MINUTE + | MONTH + | MONTHNAME + | NOW + | PERIOD_ADD + | PERIOD_DIFF + | QUARTER + | SECOND + | SUBDATE + | SYSDATE + | TIME + | TIME_TO_SEC + | TIMESTAMP + | TO_DAYS + | UNIX_TIMESTAMP + | WEEK + | YEAR ; textFunctionName diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 131b6d9116..18efef039f 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -18,7 +18,6 @@ import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CaseFuncAlternativeContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CaseFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnFilterContext; -import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ConstantFunctionContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ConvertedDataTypeContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.CountStarFunctionCallContext; import static org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.DataTypeFunctionCallContext; @@ -63,7 +62,6 @@ import org.opensearch.sql.ast.expression.And; import org.opensearch.sql.ast.expression.Case; import org.opensearch.sql.ast.expression.Cast; -import org.opensearch.sql.ast.expression.ConstantFunction; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Function; import org.opensearch.sql.ast.expression.HighlightFunction; @@ -83,7 +81,6 @@ import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext; -import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.FunctionArgsContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IdentContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.IntervalLiteralContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.NestedExpressionAtomContext; @@ -131,7 +128,7 @@ public UnresolvedExpression visitNestedExpressionAtom(NestedExpressionAtomContex @Override public UnresolvedExpression visitScalarFunctionCall(ScalarFunctionCallContext ctx) { - return visitFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs()); + return buildFunction(ctx.scalarFunctionName().getText(), ctx.functionArgs().functionArg()); } @Override @@ -206,7 +203,7 @@ public UnresolvedExpression visitWindowFunctionClause(WindowFunctionClauseContex @Override public UnresolvedExpression visitScalarWindowFunction(ScalarWindowFunctionContext ctx) { - return visitFunction(ctx.functionName.getText(), ctx.functionArgs()); + return buildFunction(ctx.functionName.getText(), ctx.functionArgs().functionArg()); } @Override @@ -404,30 +401,17 @@ public UnresolvedExpression visitMultiFieldRelevanceFunction( multiFieldRelevanceArguments(ctx)); } - private Function visitFunction(String functionName, FunctionArgsContext args) { + private Function buildFunction(String functionName, + List arg) { return new Function( functionName, - args.functionArg() + arg .stream() .map(this::visitFunctionArg) .collect(Collectors.toList()) ); } - @Override - public UnresolvedExpression visitConstantFunction(ConstantFunctionContext ctx) { - return visitConstantFunction(ctx.constantFunctionName().getText(), - ctx.functionArgs()); - } - - private UnresolvedExpression visitConstantFunction(String functionName, - FunctionArgsContext args) { - return new ConstantFunction(functionName, args.functionArg() - .stream() - .map(this::visitFunctionArg) - .collect(Collectors.toList())); - } - private QualifiedName visitIdentifiers(List identifiers) { return new QualifiedName( identifiers.stream() diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java index a955399c4d..2aed4f2834 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTest.java @@ -14,7 +14,6 @@ import static org.opensearch.sql.ast.dsl.AstDSL.alias; import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; -import static org.opensearch.sql.ast.dsl.AstDSL.constantFunction; import static org.opensearch.sql.ast.dsl.AstDSL.doubleLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; @@ -50,12 +49,7 @@ import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.sql.antlr.SQLSyntaxParser; -class AstBuilderTest { - - /** - * SQL syntax parser that helps prepare parse tree as AstBuilder input. - */ - private final SQLSyntaxParser parser = new SQLSyntaxParser(); +class AstBuilderTest extends AstBuilderTestBase { @Test public void can_build_select_literals() { @@ -679,60 +673,6 @@ public void can_build_limit_clause_with_offset() { buildAST("SELECT name FROM test LIMIT 5, 10")); } - private static Stream nowLikeFunctionsData() { - return Stream.of( - Arguments.of("now", false, false, true), - Arguments.of("current_timestamp", false, false, true), - Arguments.of("localtimestamp", false, false, true), - Arguments.of("localtime", false, false, true), - Arguments.of("sysdate", true, false, false), - Arguments.of("curtime", false, false, true), - Arguments.of("current_time", false, false, true), - Arguments.of("curdate", false, false, true), - Arguments.of("current_date", false, false, true) - ); - } - - @ParameterizedTest(name = "{0}") - @MethodSource("nowLikeFunctionsData") - public void test_now_like_functions(String name, Boolean hasFsp, Boolean hasShortcut, - Boolean isConstantFunction) { - for (var call : hasShortcut ? List.of(name, name + "()") : List.of(name + "()")) { - assertEquals( - project( - values(emptyList()), - alias(call, (isConstantFunction ? constantFunction(name) : function(name))) - ), - buildAST("SELECT " + call) - ); - - assertEquals( - project( - filter( - relation("test"), - function( - "=", - qualifiedName("data"), - (isConstantFunction ? constantFunction(name) : function(name))) - ), - AllFields.of() - ), - buildAST("SELECT * FROM test WHERE data = " + call) - ); - } - - // Unfortunately, only real functions (not ConstantFunction) might have `fsp` now. - if (hasFsp) { - assertEquals( - project( - values(emptyList()), - alias(name + "(0)", function(name, intLiteral(0))) - ), - buildAST("SELECT " + name + "(0)") - ); - } - } - @Test public void can_build_qualified_name_highlight() { Map args = new HashMap<>(); @@ -769,8 +709,4 @@ public void can_build_string_literal_highlight() { ); } - private UnresolvedPlan buildAST(String query) { - ParseTree parseTree = parser.parse(query); - return parseTree.accept(new AstBuilder(query)); - } } diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java new file mode 100644 index 0000000000..2161eb5b1a --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstBuilderTestBase.java @@ -0,0 +1,22 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import org.antlr.v4.runtime.tree.ParseTree; +import org.opensearch.sql.ast.tree.UnresolvedPlan; +import org.opensearch.sql.sql.antlr.SQLSyntaxParser; + +public class AstBuilderTestBase { + /** + * SQL syntax parser that helps prepare parse tree as AstBuilder input. + */ + private final SQLSyntaxParser parser = new SQLSyntaxParser(); + + protected UnresolvedPlan buildAST(String query) { + ParseTree parseTree = parser.parse(query); + return parseTree.accept(new AstBuilder(query)); + } +} diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java new file mode 100644 index 0000000000..19b48ca0bd --- /dev/null +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstNowLikeFunctionTest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql.parser; + +import static java.util.Collections.emptyList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.opensearch.sql.ast.dsl.AstDSL.alias; +import static org.opensearch.sql.ast.dsl.AstDSL.filter; +import static org.opensearch.sql.ast.dsl.AstDSL.function; +import static org.opensearch.sql.ast.dsl.AstDSL.intLiteral; +import static org.opensearch.sql.ast.dsl.AstDSL.project; +import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; +import static org.opensearch.sql.ast.dsl.AstDSL.relation; +import static org.opensearch.sql.ast.dsl.AstDSL.values; + +import java.util.stream.Stream; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.ast.expression.AllFields; + +class AstNowLikeFunctionTest extends AstBuilderTestBase { + + private static Stream allFunctions() { + return Stream.of("curdate", + "current_date", + "current_time", + "current_timestamp", + "curtime", + "localtimestamp", + "localtime", + "now", + "sysdate") + .map(Arguments::of); + } + + private static Stream supportFsp() { + return Stream.of("sysdate") + .map(Arguments::of); + } + + private static Stream supportShortcut() { + return Stream.of("current_date", + "current_time", + "current_timestamp", + "localtimestamp", + "localtime") + .map(Arguments::of); + } + + @ParameterizedTest + @MethodSource("allFunctions") + void project_call(String name) { + String call = name + "()"; + assertEquals( + project( + values(emptyList()), + alias(call, function(name)) + ), + buildAST("SELECT " + call) + ); + } + + @ParameterizedTest + @MethodSource("allFunctions") + void filter_call(String name) { + String call = name + "()"; + assertEquals( + project( + filter( + relation("test"), + function( + "=", + qualifiedName("data"), + function(name)) + ), + AllFields.of() + ), + buildAST("SELECT * FROM test WHERE data = " + call) + ); + } + + + @ParameterizedTest + @MethodSource("supportFsp") + void fsp(String name) { + assertEquals( + project( + values(emptyList()), + alias(name + "(0)", function(name, intLiteral(0))) + ), + buildAST("SELECT " + name + "(0)") + ); + } +}