diff --git a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java index a3ba9b1b6b..feae0c8fe0 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/ExpressionAnalyzer.java @@ -35,6 +35,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.PositionFunction; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.Span; @@ -213,6 +214,13 @@ public Expression visitHighlightFunction(HighlightFunction node, AnalysisContext return new HighlightExpression(expr); } + @Override + public Expression visitPositionFunction(PositionFunction node, AnalysisContext context) { + Expression stringPatternExpr = node.getStringPatternExpr().accept(this, context); + Expression searchStringExpr = node.getSearchStringExpr().accept(this, context); + return DSL.position(stringPatternExpr, searchStringExpr); + } + @Override public Expression visitIn(In node, AnalysisContext context) { return visitIn(node.getField(), node.getValueList(), context); diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index 53ff93eec1..e515de8666 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -27,6 +27,7 @@ import org.opensearch.sql.ast.expression.Map; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.PositionFunction; import org.opensearch.sql.ast.expression.QualifiedName; import 
org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.Span; @@ -278,6 +279,10 @@ public T visitHighlightFunction(HighlightFunction node, C context) { return visitChildren(node, context); } + public T visitPositionFunction(PositionFunction node, C context) { + return visitChildren(node, context); + } + public T visitStatement(Statement node, C context) { return visit(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 24ada1fd92..1b8c76bd4d 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -33,6 +33,7 @@ import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; import org.opensearch.sql.ast.expression.ParseMethod; +import org.opensearch.sql.ast.expression.PositionFunction; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.Span; import org.opensearch.sql.ast.expression.SpanUnit; @@ -288,6 +289,11 @@ public UnresolvedExpression highlight(UnresolvedExpression fieldName, return new HighlightFunction(fieldName, arguments); } + public UnresolvedExpression position(UnresolvedExpression stringPatternExpr, + UnresolvedExpression searchStringExpr) { + return new PositionFunction(stringPatternExpr, searchStringExpr); + } + public UnresolvedExpression window(UnresolvedExpression function, List partitionByList, List> sortList) { diff --git a/core/src/main/java/org/opensearch/sql/ast/expression/PositionFunction.java b/core/src/main/java/org/opensearch/sql/ast/expression/PositionFunction.java new file mode 100644 index 0000000000..988237ebd0 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/expression/PositionFunction.java @@ -0,0 +1,39 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package 
org.opensearch.sql.ast.expression; + +import java.util.Arrays; +import java.util.List; +import lombok.AllArgsConstructor; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; + + +/** + * Expression node of Position function. + */ +@AllArgsConstructor +@EqualsAndHashCode(callSuper = false) +@Getter +@ToString +public class PositionFunction extends UnresolvedExpression { + @Getter + private UnresolvedExpression stringPatternExpr; + @Getter + private UnresolvedExpression searchStringExpr; + + @Override + public List getChild() { + return Arrays.asList(stringPatternExpr, searchStringExpr); + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitPositionFunction(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 6486c4b676..96e9e45048 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -229,6 +229,10 @@ public static FunctionExpression cbrt(Expression... expressions) { return compile(BuiltinFunctionName.CBRT, expressions); } + public static FunctionExpression position(Expression... expressions) { + return compile(BuiltinFunctionName.POSITION, expressions); + } + public static FunctionExpression truncate(Expression... 
expressions) { return compile(BuiltinFunctionName.TRUNCATE, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index cc3db47982..69f86b7c53 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -153,23 +153,24 @@ public enum BuiltinFunctionName { /** * Text Functions. */ - SUBSTR(FunctionName.of("substr")), - SUBSTRING(FunctionName.of("substring")), - RTRIM(FunctionName.of("rtrim")), - LTRIM(FunctionName.of("ltrim")), - TRIM(FunctionName.of("trim")), - UPPER(FunctionName.of("upper")), - LOWER(FunctionName.of("lower")), - REGEXP(FunctionName.of("regexp")), + ASCII(FunctionName.of("ascii")), CONCAT(FunctionName.of("concat")), CONCAT_WS(FunctionName.of("concat_ws")), - LENGTH(FunctionName.of("length")), - STRCMP(FunctionName.of("strcmp")), - RIGHT(FunctionName.of("right")), LEFT(FunctionName.of("left")), - ASCII(FunctionName.of("ascii")), + LENGTH(FunctionName.of("length")), LOCATE(FunctionName.of("locate")), + LOWER(FunctionName.of("lower")), + LTRIM(FunctionName.of("ltrim")), + POSITION(FunctionName.of("position")), + REGEXP(FunctionName.of("regexp")), REPLACE(FunctionName.of("replace")), + RIGHT(FunctionName.of("right")), + RTRIM(FunctionName.of("rtrim")), + STRCMP(FunctionName.of("strcmp")), + SUBSTR(FunctionName.of("substr")), + SUBSTRING(FunctionName.of("substring")), + TRIM(FunctionName.of("trim")), + UPPER(FunctionName.of("upper")), /** * NULL Test. 
diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 8035728d19..b51d6a2716 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -39,22 +39,23 @@ public class TextFunction { * @param repository {@link BuiltinFunctionRepository}. */ public void register(BuiltinFunctionRepository repository) { - repository.register(substr()); - repository.register(substring()); - repository.register(ltrim()); - repository.register(rtrim()); - repository.register(trim()); - repository.register(lower()); - repository.register(upper()); + repository.register(ascii()); repository.register(concat()); repository.register(concat_ws()); - repository.register(length()); - repository.register(strcmp()); - repository.register(right()); repository.register(left()); - repository.register(ascii()); + repository.register(length()); repository.register(locate()); + repository.register(lower()); + repository.register(ltrim()); + repository.register(position()); repository.register(replace()); + repository.register(right()); + repository.register(rtrim()); + repository.register(strcmp()); + repository.register(substr()); + repository.register(substring()); + repository.register(trim()); + repository.register(upper()); } /** @@ -241,6 +242,18 @@ private DefaultFunctionResolver locate() { TextFunction::exprLocate), INTEGER, STRING, STRING, INTEGER)); } + /** + * Returns the position of the first occurrence of a substring in a string starting from 1. + * Returns 0 if substring is not in string. + * Returns NULL if any argument is NULL. 
+ * Supports following signature: + * (STRING IN STRING) -> INTEGER + */ + private DefaultFunctionResolver position() { + return define(BuiltinFunctionName.POSITION.getName(), + impl(nullMissingHandling(TextFunction::exprPosition), INTEGER, STRING, STRING)); + } + /** * REPLACE(str, from_str, to_str) returns the string str with all occurrences of * the string from_str replaced by the string to_str. @@ -313,6 +326,10 @@ private static ExprValue exprLocate(ExprValue subStr, ExprValue str, ExprValue p str.stringValue().indexOf(subStr.stringValue(), pos.integerValue() - 1) + 1); } + private static ExprValue exprPosition(ExprValue subStr, ExprValue str) { + return exprLocate(subStr, str); + } + private static ExprValue exprReplace(ExprValue str, ExprValue from, ExprValue to) { return new ExprStringValue(str.stringValue().replaceAll(from.stringValue(), to.stringValue())); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java index 913593add3..2f97c6480e 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -16,6 +16,7 @@ import org.opensearch.sql.ast.expression.Alias; import org.opensearch.sql.ast.expression.HighlightFunction; import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.PositionFunction; import org.opensearch.sql.expression.NamedExpression; import org.springframework.context.annotation.Configuration; import org.springframework.test.context.ContextConfiguration; @@ -48,4 +49,14 @@ void visit_highlight() { NamedExpression analyze = analyzer.analyze(alias, analysisContext); assertEquals("highlight(fieldA)", analyze.getNameOrAlias()); } + + @Test + void visit_position() { + Alias alias = AstDSL.alias("position(fieldA IN fieldB)", + new 
PositionFunction(AstDSL.stringLiteral("fieldA"), AstDSL.stringLiteral("fieldB"))); + NamedExpressionAnalyzer analyzer = new NamedExpressionAnalyzer(expressionAnalyzer); + + NamedExpression analyze = analyzer.analyze(alias, analysisContext); + assertEquals("position(fieldA IN fieldB)", analyze.getNameOrAlias()); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java index 502fe70ec8..5e32678b94 100644 --- a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java @@ -377,6 +377,26 @@ void locate() { DSL.locate(missingRef, DSL.literal("hello"), DSL.literal(1)))); } + @Test + void position() { + FunctionExpression expression = DSL.position( + DSL.literal("world"), + DSL.literal("helloworldworld")); + assertEquals(INTEGER, expression.type()); + assertEquals(6, eval(expression).integerValue()); + + expression = DSL.position( + DSL.literal("abc"), + DSL.literal("hello world")); + assertEquals(INTEGER, expression.type()); + assertEquals(0, eval(expression).integerValue()); + + when(nullRef.type()).thenReturn(STRING); + assertEquals(nullValue(), eval(DSL.position(nullRef, DSL.literal("hello")))); + when(missingRef.type()).thenReturn(STRING); + assertEquals(missingValue(), eval(DSL.position(missingRef, DSL.literal("hello")))); + } + @Test void replace() { FunctionExpression expression = DSL.replace( diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 9c26427143..3d084152c1 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2331,6 +2331,31 @@ Example:: +---------------------+---------------------+ +POSITION +-------- + +Description +>>>>>>>>>>> + +Usage: The syntax POSITION(substr IN str) returns the position of the first occurrence of substring substr in string str. Returns 0 if substr is not in str.
Returns NULL if any argument is NULL. + +Argument type: STRING, STRING + +Return type integer: + +(STRING IN STRING) -> INTEGER + +Example:: + + os> SELECT POSITION('world' IN 'helloworld'), POSITION('invalid' IN 'helloworld') + fetched rows / total rows = 1/1 + +-------------------------------------+---------------------------------------+ + | POSITION('world' IN 'helloworld') | POSITION('invalid' IN 'helloworld') | + |-------------------------------------+---------------------------------------| + | 6 | 0 | + +-------------------------------------+---------------------------------------+ + + REPLACE ------- diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/PositionFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/PositionFunctionIT.java new file mode 100644 index 0000000000..f51a3a0977 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/PositionFunctionIT.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.schema; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; +import static org.opensearch.sql.util.MatcherUtils.verifySchema; + +public class PositionFunctionIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + loadIndex(Index.PEOPLE2); + loadIndex(Index.CALCS); + } + + @Test + public void position_function_test() { + String query = "SELECT firstname, position('a' IN firstname) FROM %s"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_PEOPLE2)); + + verifySchema(response, schema("firstname", null, "keyword"), + schema("position('a' IN firstname)", null, "integer")); + assertEquals(12,
response.getInt("total")); + + verifyDataRows(response, + rows("Daenerys", 2), rows("Hattie", 2), + rows("Nanette", 2), rows("Dale", 2), + rows("Elinor", 0), rows("Virginia", 8), + rows("Dillard", 5), rows("Mcgee", 0), + rows("Aurelia", 7), rows("Fulton", 0), + rows("Burton", 0), rows("Josie", 0)); + } + + @Test + public void position_function_with_nulls_test() { + String query = "SELECT str2, position('ee' IN str2) FROM %s"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_CALCS)); + + verifySchema(response, schema("str2", null, "keyword"), + schema("position('ee' IN str2)", null, "integer")); + assertEquals(17, response.getInt("total")); + + verifyDataRows(response, + rows("one", 0), rows("two", 0), + rows("three", 4), rows(null, null), + rows("five", 0), rows("six", 0), + rows(null, null), rows("eight", 0), + rows("nine", 0), rows("ten", 0), + rows("eleven", 0), rows("twelve", 0), + rows(null, null), rows("fourteen", 6), + rows("fifteen", 5), rows("sixteen", 5), + rows(null, null)); + } + + @Test + public void position_function_with_string_literals_test() { + String query = "SELECT position('world' IN 'hello world')"; + JSONObject response = executeJdbcRequest(query); + + verifySchema(response, schema("position('world' IN 'hello world')", null, "integer")); + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(7)); + } + + @Test + public void position_function_with_only_fields_as_args_test() { + String query = "SELECT position(str3 IN str2) FROM %s WHERE str2 IN ('one', 'two', 'three')"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_CALCS)); + + verifySchema(response, schema("position(str3 IN str2)", null, "integer")); + assertEquals(3, response.getInt("total")); + + verifyDataRows(response, rows(3), rows(0), rows(4)); + } + + @Test + public void position_function_with_function_as_arg_test() { + String query = "SELECT position(upper(str3) IN str1) FROM 
%s WHERE str1 LIKE 'BINDING SUPPLIES'"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_CALCS)); + + verifySchema(response, schema("position(upper(str3) IN str1)", null, "integer")); + assertEquals(1, response.getInt("total")); + + verifyDataRows(response, rows(15)); + } + + @Test + public void position_function_in_where_clause_test() { + String query = "SELECT str2 FROM %s WHERE position(str3 IN str2)=1"; + JSONObject response = executeJdbcRequest(String.format(query, TestsConstants.TEST_INDEX_CALCS)); + + verifySchema(response, schema("str2", null, "keyword")); + assertEquals(2, response.getInt("total")); + + verifyDataRows(response, rows("eight"), rows("eleven")); + } + + @Test + public void position_function_with_null_args_test() { + String query1 = "SELECT str2, position(null IN str2) FROM %s WHERE str2 IN ('one')"; + String query2 = "SELECT str2, position(str2 IN null) FROM %s WHERE str2 IN ('one')"; + JSONObject response1 = executeJdbcRequest(String.format(query1, TestsConstants.TEST_INDEX_CALCS)); + JSONObject response2 = executeJdbcRequest(String.format(query2, TestsConstants.TEST_INDEX_CALCS)); + + verifySchema(response1, + schema("str2", null, "keyword"), + schema("position(null IN str2)", null, "integer")); + assertEquals(1, response1.getInt("total")); + + verifySchema(response2, + schema("str2", null, "keyword"), + schema("position(str2 IN null)", null, "integer")); + assertEquals(1, response2.getInt("total")); + + verifyDataRows(response1, rows("one", null)); + verifyDataRows(response2, rows("one", null)); + } +} diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index 9e0a409401..b15363445e 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -234,6 +234,7 @@ NULLIF: 'NULLIF'; PERIOD_ADD: 'PERIOD_ADD'; PERIOD_DIFF: 'PERIOD_DIFF'; PI: 'PI'; +POSITION: 'POSITION'; POW: 'POW'; POWER: 'POWER'; RADIANS: 'RADIANS'; diff 
--git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index b3fd29b342..620a1811fd 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -302,6 +302,7 @@ functionCall | relevanceFunction #relevanceFunctionCall | highlightFunction #highlightFunctionCall | constantFunction #constantFunctionCall + | positionFunction #positionFunctionCall ; constantFunction @@ -312,6 +313,10 @@ highlightFunction : HIGHLIGHT LR_BRACKET relevanceField (COMMA highlightArg)* RR_BRACKET ; +positionFunction + : POSITION LR_BRACKET functionArg IN functionArg RR_BRACKET + ; + scalarFunctionName : mathematicalFunctionName | dateTimeFunctionName diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 131b6d9116..7f6b52e87c 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -72,6 +72,7 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.Not; import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.PositionFunction; import org.opensearch.sql.ast.expression.QualifiedName; import org.opensearch.sql.ast.expression.RelevanceFieldList; import org.opensearch.sql.ast.expression.UnresolvedArgument; @@ -428,6 +429,13 @@ private UnresolvedExpression visitConstantFunction(String functionName, .collect(Collectors.toList())); } + @Override + public UnresolvedExpression visitPositionFunction( + OpenSearchSQLParser.PositionFunctionContext ctx) { + return new PositionFunction(visitFunctionArg(ctx.functionArg(0)), + visitFunctionArg(ctx.functionArg(1))); + } + private QualifiedName visitIdentifiers(List identifiers) { return new QualifiedName( identifiers.stream() diff --git 
a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index cb00ea2f18..c1e3faad75 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -21,6 +21,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.not; import static org.opensearch.sql.ast.dsl.AstDSL.nullLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.or; +import static org.opensearch.sql.ast.dsl.AstDSL.position; import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.timeLiteral; @@ -328,6 +329,22 @@ public void canBuildQualifiedNameHighlightFunction() { ); } + @Test + public void canBuildPositionFunction() { + assertEquals( + position(AstDSL.qualifiedName("fieldA"), AstDSL.qualifiedName("fieldB")), + buildExprAst("position(fieldA IN fieldB)") + ); + } + + @Test + public void canBuildStringLiteralPositionFunction() { + assertEquals( + position(AstDSL.stringLiteral("fieldA"), AstDSL.stringLiteral("fieldB")), + buildExprAst("position(\"fieldA\" IN \"fieldB\")") + ); + } + @Test public void canBuildWindowFunctionWithoutOrderBy() { assertEquals(