From f53077044d814b32105557326c72e82fe38c501a Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Wed, 7 Dec 2022 11:25:47 -0800 Subject: [PATCH 1/5] Add position() string function to PPL (#1147) * Add position() string function to PPL (#184) Signed-off-by: Margarit Hakobyan --- docs/user/ppl/functions/string.rst | 25 +++++ .../sql/ppl/PositionFunctionIT.java | 100 ++++++++++++++++++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 1 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 11 ++ .../sql/ppl/parser/AstExpressionBuilder.java | 10 ++ .../ppl/parser/AstExpressionBuilderTest.java | 13 +++ 6 files changed, 160 insertions(+) create mode 100644 integ-test/src/test/java/org/opensearch/sql/ppl/PositionFunctionIT.java diff --git a/docs/user/ppl/functions/string.rst b/docs/user/ppl/functions/string.rst index 116c28b0e2..b14acc88e0 100644 --- a/docs/user/ppl/functions/string.rst +++ b/docs/user/ppl/functions/string.rst @@ -150,6 +150,31 @@ Example:: +---------------------+---------------------+ +POSITION +-------- + +Description +>>>>>>>>>>> + +Usage: The syntax POSITION(substr IN str) returns the position of the first occurrence of substring substr in string str. Returns 0 if substr is not in str. Returns NULL if any argument is NULL.
+ +Argument type: STRING, STRING + +Return type: INTEGER + +(STRING IN STRING) -> INTEGER + +Example:: + + os> source=people | eval `POSITION('world' IN 'helloworld')` = POSITION('world' IN 'helloworld'), `POSITION('invalid' IN 'helloworld')`= POSITION('invalid' IN 'helloworld') | fields `POSITION('world' IN 'helloworld')`, `POSITION('invalid' IN 'helloworld')` + fetched rows / total rows = 1/1 + +-------------------------------------+---------------------------------------+ + | POSITION('world' IN 'helloworld') | POSITION('invalid' IN 'helloworld') | + |-------------------------------------+---------------------------------------| + | 6 | 0 | + +-------------------------------------+---------------------------------------+ + + RIGHT ----- diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/PositionFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/PositionFunctionIT.java new file mode 100644 index 0000000000..24319a0cb8 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/PositionFunctionIT.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import org.junit.Test; + +import java.io.IOException; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_CALCS; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +public class PositionFunctionIT extends PPLIntegTestCase { + @Override + public void init() throws IOException { + loadIndex(Index.CALCS); + } + + @Test + public void test_position_function() throws IOException { + String query = "source=" + TEST_INDEX_CALCS + + " | eval f=position('ON', str1) | fields f"; + + var result = executeQuery(query); + + assertEquals(17, result.getInt("total")); + verifyDataRows(result, + rows(7), rows(7), + rows(2), rows(0), + rows(0), rows(0), + rows(0), rows(0), + rows(0), rows(0), + rows(0), rows(0), + rows(0),
rows(0), + rows(0), rows(0), + rows(0)); + } + + @Test + public void test_position_function_with_fields_only() throws IOException { + String query = "source=" + TEST_INDEX_CALCS + + " | eval f=position(str3 IN str2) | where str2 IN ('one', 'two', 'three')| fields f"; + + var result = executeQuery(query); + + assertEquals(3, result.getInt("total")); + verifyDataRows(result, rows(3), rows(0), rows(4)); + } + + @Test + public void test_position_function_with_string_literals() throws IOException { + String query = "source=" + TEST_INDEX_CALCS + + " | eval f=position('world' IN 'hello world') | where str2='one' | fields f"; + + var result = executeQuery(query); + + assertEquals(1, result.getInt("total")); + verifyDataRows(result, rows(7)); + } + + @Test + public void test_position_function_with_nulls() throws IOException { + String query = "source=" + TEST_INDEX_CALCS + + " | eval f=position('ee' IN str2) | where isnull(str2) | fields str2,f"; + + var result = executeQuery(query); + + assertEquals(4, result.getInt("total")); + verifyDataRows(result, + rows(null, null), + rows(null, null), + rows(null, null), + rows(null, null)); + } + + @Test + public void test_position_function_with_function_as_arg() throws IOException { + String query = "source=" + TEST_INDEX_CALCS + + " | eval f=position(upper(str3) IN str1) | where like(str1, 'BINDING SUPPLIES') | fields f"; + + var result = executeQuery(query); + + assertEquals(1, result.getInt("total")); + verifyDataRows(result, rows(15)); + } + + @Test + public void test_position_function_with_function_in_where_clause() throws IOException { + String query = "source=" + TEST_INDEX_CALCS + + " | where position(str3 IN str2)=1 | fields str2"; + + var result = executeQuery(query); + + assertEquals(2, result.getInt("total")); + verifyDataRows(result, rows("eight"), rows("eleven")); + } +} diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index a601a547ee..8c0340e7f1 100644 --- 
a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -215,6 +215,7 @@ LOG10: 'LOG10'; LOG2: 'LOG2'; MOD: 'MOD'; PI: 'PI'; +POSITION: 'POSITION'; POW: 'POW'; POWER: 'POWER'; RAND: 'RAND'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 76d8e38eff..6dba1ae783 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -258,6 +258,7 @@ valueExpression | LT_PRTHS left=valueExpression binaryOperator right=valueExpression RT_PRTHS #parentheticBinaryArithmetic | primaryExpression #valueExpressionDefault + | positionFunction #positionFunctionCall ; primaryExpression @@ -267,6 +268,10 @@ primaryExpression | literalValue ; +positionFunction + : positionFunctionName LT_PRTHS functionArg IN functionArg RT_PRTHS + ; + booleanExpression : booleanFunctionCall ; @@ -362,6 +367,7 @@ evalFunctionName | textFunctionBase | conditionFunctionBase | systemFunctionBase + | positionFunctionName ; functionArgs @@ -484,6 +490,10 @@ textFunctionBase | RIGHT | LEFT | ASCII | LOCATE | REPLACE ; +positionFunctionName + : POSITION + ; + /** operators */ comparisonOperator : EQUAL | NOT_EQUAL | LESS | NOT_LESS | GREATER | NOT_GREATER | REGEXP @@ -603,4 +613,5 @@ keywordsCanBeId | dateAndTimeFunctionBase | textFunctionBase | mathematicalFunctionBase + | positionFunctionName ; diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 115bcf3cd8..68608e23ad 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -9,6 +9,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.qualifiedName; import static org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NOT_NULL; import static 
org.opensearch.sql.expression.function.BuiltinFunctionName.IS_NULL; +import static org.opensearch.sql.expression.function.BuiltinFunctionName.POSITION; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BinaryArithmeticContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanFunctionCallContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.BooleanLiteralContext; @@ -290,6 +291,15 @@ public UnresolvedExpression visitTableSource(TableSourceContext ctx) { } } + @Override + public UnresolvedExpression visitPositionFunction( + OpenSearchPPLParser.PositionFunctionContext ctx) { + return new Function( + POSITION.getName().getFunctionName(), + Arrays.asList(visitFunctionArg(ctx.functionArg(0)), + visitFunctionArg(ctx.functionArg(1)))); + } + /** * Literal and value. */ diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index e4048c5fe1..dbdfb71aa7 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -181,6 +181,19 @@ public void testEvalFunctionExprNoArgs() { )); } + @Test + public void testPositionFunctionExpr() { + assertEqual("source=t | eval f=position('substr' IN 'str')", + eval( + relation("t"), + let( + field("f"), + function("position", + stringLiteral("substr"), stringLiteral("str")) + ) + )); + } + @Test public void testEvalBinaryOperationExpr() { assertEqual("source=t | eval f=a+b", From 64a3794cb015430bc30199885f1d91c5b34d5654 Mon Sep 17 00:00:00 2001 From: Chen Dai Date: Wed, 7 Dec 2022 11:29:11 -0800 Subject: [PATCH 2/5] Improve pushdown optimization and logical to physical transformation (#1091) * Add new table scan builder and optimizer rules Signed-off-by: Chen Dai * Fix jacoco test coverage Signed-off-by: Chen Dai * Update javadoc with more 
details Signed-off-by: Chen Dai * Fix highlight pushdown issue Signed-off-by: Chen Dai * Rename new class more properly Signed-off-by: Chen Dai * Fix default sort by doc issue Signed-off-by: Chen Dai * Rename visit method and javadoc Signed-off-by: Chen Dai * Move table scan builder and optimize rule to read package Signed-off-by: Chen Dai * Fix sort push down issue Signed-off-by: Chen Dai * Move sortByFields to parent scan builder Signed-off-by: Chen Dai * Add back old test Signed-off-by: Chen Dai Signed-off-by: Chen Dai --- .../sql/planner/DefaultImplementor.java | 6 + .../logical/LogicalPlanNodeVisitor.java | 6 + .../optimizer/LogicalPlanOptimizer.java | 26 +- .../planner/optimizer/pattern/Patterns.java | 85 +++ .../rule/read/CreateTableScanBuilder.java | 51 ++ .../rule/read/TableScanPushDown.java | 129 ++++ .../org/opensearch/sql/storage/Table.java | 14 + .../sql/storage/read/TableScanBuilder.java | 111 ++++ .../sql/planner/DefaultImplementorTest.java | 15 + .../logical/LogicalPlanNodeVisitorTest.java | 11 + .../optimizer/LogicalPlanOptimizerTest.java | 187 +++++- .../optimizer/pattern/PatternsTest.java | 9 + .../logical/OpenSearchLogicalIndexAgg.java | 80 --- .../logical/OpenSearchLogicalIndexScan.java | 97 --- ...OpenSearchLogicalPlanOptimizerFactory.java | 47 -- .../logical/rule/MergeAggAndIndexScan.java | 57 -- .../logical/rule/MergeAggAndRelation.java | 54 -- .../logical/rule/MergeFilterAndRelation.java | 53 -- .../logical/rule/MergeLimitAndIndexScan.java | 54 -- .../logical/rule/MergeLimitAndRelation.java | 49 -- .../logical/rule/MergeSortAndIndexAgg.java | 82 --- .../logical/rule/MergeSortAndIndexScan.java | 70 -- .../logical/rule/MergeSortAndRelation.java | 53 -- .../logical/rule/OptimizationRuleUtils.java | 66 -- .../logical/rule/PushProjectAndIndexScan.java | 63 -- .../logical/rule/PushProjectAndRelation.java | 67 -- .../request/OpenSearchRequestBuilder.java | 17 + .../agg/CompositeAggregationParser.java | 2 + 
.../opensearch/response/agg/FilterParser.java | 2 + .../response/agg/MetricParserHelper.java | 2 + .../response/agg/SingleValueParser.java | 2 + .../opensearch/response/agg/StatsParser.java | 2 + .../response/agg/TopHitsParser.java | 2 + .../opensearch/storage/OpenSearchIndex.java | 115 +--- ...OpenSearchIndexScanAggregationBuilder.java | 92 +++ .../scan/OpenSearchIndexScanBuilder.java | 105 +++ .../scan/OpenSearchIndexScanQueryBuilder.java | 133 ++++ .../logical/OpenSearchLogicOptimizerTest.java | 576 ----------------- .../OpenSearchLogicalIndexScanTest.java | 24 - .../request/OpenSearchRequestBuilderTest.java | 177 ++++- .../OpenSearchDefaultImplementorTest.java | 54 +- .../storage/OpenSearchIndexTest.java | 273 +------- .../OpenSearchIndexScanOptimizationTest.java | 609 ++++++++++++++++++ .../sql/opensearch/utils/Utils.java | 131 ---- 44 files changed, 1810 insertions(+), 2050 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java create mode 100644 core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java delete mode 100644 
opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java delete mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java delete mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index 9f2c2c5fa8..4a5276418d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ 
b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -34,6 +34,7 @@ import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Default implementor for implementing logical to physical translation. "Default" here means all @@ -123,6 +124,11 @@ public PhysicalPlan visitLimit(LogicalLimit node, C context) { return new LimitOperator(visitChild(node, context), node.getLimit(), node.getOffset()); } + @Override + public PhysicalPlan visitTableScanBuilder(TableScanBuilder plan, C context) { + return plan.build(); + } + @Override public PhysicalPlan visitRelation(LogicalRelation node, C context) { throw new UnsupportedOperationException("Storage engine is responsible for " diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 28539562e7..0386eb6e2a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -6,6 +6,8 @@ package org.opensearch.sql.planner.logical; +import org.opensearch.sql.storage.read.TableScanBuilder; + /** * The visitor of {@link LogicalPlan}. 
* @@ -22,6 +24,10 @@ public R visitRelation(LogicalRelation plan, C context) { return visitNode(plan, context); } + public R visitTableScanBuilder(TableScanBuilder plan, C context) { + return visitNode(plan, context); + } + public R visitFilter(LogicalFilter plan, C context) { return visitNode(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java index 0e547df68d..f241e76993 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizer.java @@ -15,6 +15,8 @@ import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.optimizer.rule.MergeFilterAndFilter; import org.opensearch.sql.planner.optimizer.rule.PushFilterUnderSort; +import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; +import org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown; /** * {@link LogicalPlan} Optimizer. 
@@ -39,8 +41,21 @@ public LogicalPlanOptimizer(List> rules) { */ public static LogicalPlanOptimizer create() { return new LogicalPlanOptimizer(Arrays.asList( + /* + * Phase 1: Transformations that rely on relational algebra equivalence + */ new MergeFilterAndFilter(), - new PushFilterUnderSort())); + new PushFilterUnderSort(), + /* + * Phase 2: Transformations that rely on data source push down capability + */ + new CreateTableScanBuilder(), + TableScanPushDown.PUSH_DOWN_FILTER, + TableScanPushDown.PUSH_DOWN_AGGREGATION, + TableScanPushDown.PUSH_DOWN_SORT, + TableScanPushDown.PUSH_DOWN_LIMIT, + TableScanPushDown.PUSH_DOWN_HIGHLIGHT, + TableScanPushDown.PUSH_DOWN_PROJECT)); } /** @@ -63,7 +78,14 @@ private LogicalPlan internalOptimize(LogicalPlan plan) { Match match = DEFAULT_MATCHER.match(rule.pattern(), node); if (match.isPresent()) { node = rule.apply(match.value(), match.captures()); - done = false; + + // For new TableScanPushDown impl, pattern match doesn't necessarily cause + // push down to happen. So reiterate all rules against the node only if the node + // is actually replaced by any rule. 
+ // TODO: may need to introduce fixed point or maximum iteration limit in future + if (node != match.value()) { + done = false; + } } } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java index 73d0f8c577..0ba478594a 100644 --- a/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/pattern/Patterns.java @@ -6,10 +6,22 @@ package org.opensearch.sql.planner.optimizer.pattern; +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Pattern; import com.facebook.presto.matching.Property; +import com.facebook.presto.matching.PropertyPattern; import java.util.Optional; import lombok.experimental.UtilityClass; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Pattern helper class. @@ -17,6 +29,55 @@ @UtilityClass public class Patterns { + /** + * Logical filter with a given pattern on inner field. + */ + public static Pattern filter(Pattern pattern) { + return Pattern.typeOf(LogicalFilter.class).with(source(pattern)); + } + + /** + * Logical aggregate operator with a given pattern on inner field. + */ + public static Pattern aggregate(Pattern pattern) { + return Pattern.typeOf(LogicalAggregation.class).with(source(pattern)); + } + + /** + * Logical sort operator with a given pattern on inner field. 
+ */ + public static Pattern sort(Pattern pattern) { + return Pattern.typeOf(LogicalSort.class).with(source(pattern)); + } + + /** + * Logical limit operator with a given pattern on inner field. + */ + public static Pattern limit(Pattern pattern) { + return Pattern.typeOf(LogicalLimit.class).with(source(pattern)); + } + + /** + * Logical highlight operator with a given pattern on inner field. + */ + public static Pattern highlight(Pattern pattern) { + return Pattern.typeOf(LogicalHighlight.class).with(source(pattern)); + } + + /** + * Logical project operator with a given pattern on inner field. + */ + public static Pattern project(Pattern pattern) { + return Pattern.typeOf(LogicalProject.class).with(source(pattern)); + } + + /** + * Pattern for {@link TableScanBuilder} and capture it meanwhile. + */ + public static Pattern scanBuilder() { + return Pattern.typeOf(TableScanBuilder.class).capturedAs(Capture.newCapture()); + } + /** * LogicalPlan source {@link Property}. */ @@ -25,4 +86,28 @@ public static Property source() { ? Optional.of(plan.getChild().get(0)) : Optional.empty()); } + + /** + * Source (children field) with a given pattern. + */ + @SuppressWarnings("unchecked") + public static + PropertyPattern source(Pattern pattern) { + Property property = Property.optionalProperty("source", + plan -> plan.getChild().size() == 1 + ? Optional.of((T) plan.getChild().get(0)) + : Optional.empty()); + + return property.matching(pattern); + } + + /** + * Logical relation with table field. + */ + public static Property table() { + return Property.optionalProperty("table", + plan -> plan instanceof LogicalRelation + ? 
Optional.of(((LogicalRelation) plan).getTable()) + : Optional.empty()); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java new file mode 100644 index 0000000000..dbe61ca8c3 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/CreateTableScanBuilder.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule.read; + +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.table; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import lombok.Getter; +import lombok.experimental.Accessors; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.optimizer.Rule; +import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Rule that replaces logical relation operator with {@link TableScanBuilder} for later + * push down optimization. All push down optimization rules that depend on table scan + * builder need to run after this. + */ +public class CreateTableScanBuilder implements Rule { + + /** Capture the table inside matched logical relation operator. */ + private final Capture capture; + + /** Pattern that matches logical relation operator. */ + @Accessors(fluent = true) + @Getter + private final Pattern pattern; + + /** + * Construct create table scan builder rule.
+ */ + public CreateTableScanBuilder() { + this.capture = Capture.newCapture(); + this.pattern = Pattern.typeOf(LogicalRelation.class) + .with(table().capturedAs(capture)); + } + + @Override + public LogicalPlan apply(LogicalRelation plan, Captures captures) { + TableScanBuilder scanBuilder = captures.get(capture).createScanBuilder(); + // TODO: Remove this after Prometheus refactored to new table scan builder too + return (scanBuilder == null) ? plan : scanBuilder; + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java new file mode 100644 index 0000000000..556a12bb34 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/optimizer/rule/read/TableScanPushDown.java @@ -0,0 +1,129 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.optimizer.rule.read; + +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.aggregate; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.filter; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.highlight; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.limit; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.project; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.scanBuilder; +import static org.opensearch.sql.planner.optimizer.pattern.Patterns.sort; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.TableScanPushDownBuilder.match; + +import com.facebook.presto.matching.Capture; +import com.facebook.presto.matching.Captures; +import com.facebook.presto.matching.Pattern; +import com.facebook.presto.matching.pattern.CapturePattern; +import com.facebook.presto.matching.pattern.WithPattern; +import java.util.function.BiFunction; +import 
org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.Rule; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Rule template for all table scan push down rules. Because all push down optimization rules + * have similar workflow in common, such as a pattern that match an operator on top of table scan + * builder, and action that eliminates the original operator if pushed down, this class helps + * remove redundant code and improve readability. + * + * @param logical plan node type + */ +public class TableScanPushDown implements Rule { + + /** Push down optimize rule for filtering condition. */ + public static final Rule PUSH_DOWN_FILTER = + match( + filter( + scanBuilder())) + .apply((filter, scanBuilder) -> scanBuilder.pushDownFilter(filter)); + + /** Push down optimize rule for aggregate operator. */ + public static final Rule PUSH_DOWN_AGGREGATION = + match( + aggregate( + scanBuilder())) + .apply((agg, scanBuilder) -> scanBuilder.pushDownAggregation(agg)); + + /** Push down optimize rule for sort operator. */ + public static final Rule PUSH_DOWN_SORT = + match( + sort( + scanBuilder())) + .apply((sort, scanBuilder) -> scanBuilder.pushDownSort(sort)); + + /** Push down optimize rule for limit operator. */ + public static final Rule PUSH_DOWN_LIMIT = + match( + limit( + scanBuilder())) + .apply((limit, scanBuilder) -> scanBuilder.pushDownLimit(limit)); + + public static final Rule PUSH_DOWN_PROJECT = + match( + project( + scanBuilder())) + .apply((project, scanBuilder) -> scanBuilder.pushDownProject(project)); + + public static final Rule PUSH_DOWN_HIGHLIGHT = + match( + highlight( + scanBuilder())) + .apply((highlight, scanBuilder) -> scanBuilder.pushDownHighlight(highlight)); + + + /** Pattern that matches a plan node. */ + private final WithPattern pattern; + + /** Capture table scan builder inside a plan node. 
*/ + private final Capture capture; + + /** Push down function applied to the plan node and captured table scan builder. */ + private final BiFunction pushDownFunction; + + + @SuppressWarnings("unchecked") + private TableScanPushDown(WithPattern pattern, + BiFunction pushDownFunction) { + this.pattern = pattern; + this.capture = ((CapturePattern) pattern.getPattern()).capture(); + this.pushDownFunction = pushDownFunction; + } + + @Override + public Pattern pattern() { + return pattern; + } + + @Override + public LogicalPlan apply(T plan, Captures captures) { + TableScanBuilder scanBuilder = captures.get(capture); + if (pushDownFunction.apply(plan, scanBuilder)) { + return scanBuilder; + } + return plan; + } + + /** + * Custom builder class other than generated by Lombok to provide more readable code. + */ + static class TableScanPushDownBuilder { + + private WithPattern pattern; + + public static + TableScanPushDownBuilder match(Pattern pattern) { + TableScanPushDownBuilder builder = new TableScanPushDownBuilder<>(); + builder.pattern = (WithPattern) pattern; + return builder; + } + + public TableScanPushDown apply( + BiFunction pushDownFunction) { + return new TableScanPushDown<>(pattern, pushDownFunction); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/storage/Table.java b/core/src/main/java/org/opensearch/sql/storage/Table.java index f43531e2a6..ae0aaaf17b 100644 --- a/core/src/main/java/org/opensearch/sql/storage/Table.java +++ b/core/src/main/java/org/opensearch/sql/storage/Table.java @@ -11,6 +11,7 @@ import org.opensearch.sql.executor.streaming.StreamingSource; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Table. 
@@ -45,7 +46,9 @@ default void create(Map schema) { * * @param plan logical plan * @return physical plan + * @deprecated because of new {@link TableScanBuilder} implementation */ + @Deprecated(since = "2.5.0") PhysicalPlan implement(LogicalPlan plan); /** @@ -54,11 +57,22 @@ default void create(Map schema) { * * @param plan logical plan. * @return logical plan. + * @deprecated because of new {@link TableScanBuilder} implementation */ + @Deprecated(since = "2.5.0") default LogicalPlan optimize(LogicalPlan plan) { return plan; } + /** + * Create table scan builder for logical to physical transformation. + * + * @return table scan builder + */ + default TableScanBuilder createScanBuilder() { + return null; // TODO: Enforce all subclasses to implement this later + } + /** * Translate {@link Table} to {@link StreamingSource} if possible. */ diff --git a/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java new file mode 100644 index 0000000000..c0fdf36e70 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/storage/read/TableScanBuilder.java @@ -0,0 +1,111 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.storage.read; + +import java.util.Collections; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; + +/** + * A TableScanBuilder represents transition state between logical planning and physical planning + * for table scan operator. 
The concrete implementation class gets involved in the logical + * optimization through this abstraction and thus get the chance to handle push down optimization + * without intruding core engine. + */ +public abstract class TableScanBuilder extends LogicalPlan { + + /** + * Construct and initialize children to empty list. + */ + public TableScanBuilder() { + super(Collections.emptyList()); + } + + /** + * Build table scan operator. + * + * @return table scan operator + */ + public abstract TableScanOperator build(); + + /** + * Can a given filter operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param filter logical filter operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownFilter(LogicalFilter filter) { + return false; + } + + /** + * Can a given aggregate operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param aggregation logical aggregate operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownAggregation(LogicalAggregation aggregation) { + return false; + } + + /** + * Can a given sort operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param sort logical sort operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownSort(LogicalSort sort) { + return false; + } + + /** + * Can a given limit operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param limit logical limit operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownLimit(LogicalLimit limit) { + return false; + } + + /** + * Can a given project operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. 
+ * + * @param project logical project operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownProject(LogicalProject project) { + return false; + } + + /** + * Can a given highlight operator be pushed down to table scan builder. Assume no such support + * by default unless subclass override this. + * + * @param highlight logical highlight operator + * @return true if pushed down, otherwise false + */ + public boolean pushDownHighlight(LogicalHighlight highlight) { + return false; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitTableScanBuilder(this, context); + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 3a6a95764c..2322e4684e 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -36,6 +36,7 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; @@ -55,6 +56,8 @@ import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; @ExtendWith(MockitoExtension.class) class DefaultImplementorTest { @@ -197,4 +200,16 @@ public void visitWindowOperatorShouldReturnPhysicalWindowOperator() { assertEquals(physicalPlan, logicalPlan.accept(implementor, null)); } + + @Test + public void visitTableScanBuilderShouldBuildTableScanOperator() { + TableScanOperator tableScanOperator = Mockito.mock(TableScanOperator.class); 
+ TableScanBuilder tableScanBuilder = new TableScanBuilder() { + @Override + public TableScanOperator build() { + return tableScanOperator; + } + }; + assertEquals(tableScanOperator, tableScanBuilder.accept(implementor, null)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index 03eeb9c626..33c6b02c38 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -32,6 +32,8 @@ import org.opensearch.sql.expression.aggregation.Aggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; /** * Todo. Temporary added for UT coverage, Will be removed. @@ -72,6 +74,15 @@ public void testAbstractPlanNodeVisitorShouldReturnNull() { assertNull(relation.accept(new LogicalPlanNodeVisitor() { }, null)); + LogicalPlan tableScanBuilder = new TableScanBuilder() { + @Override + public TableScanOperator build() { + return null; + } + }; + assertNull(tableScanBuilder.accept(new LogicalPlanNodeVisitor() { + }, null)); + LogicalPlan filter = LogicalPlanDSL.filter(relation, expression); assertNull(filter.accept(new LogicalPlanNodeVisitor() { }, null)); diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index 9f3035888f..e2510ec464 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -7,29 +7,53 @@ package org.opensearch.sql.planner.optimizer; import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.longValue; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; +import com.google.common.collect.ImmutableList; +import java.util.Collections; +import java.util.Map; import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; -import org.opensearch.sql.analysis.AnalyzerTestBase; +import org.mockito.Mock; +import org.mockito.Mockito; +import org.mockito.Spy; +import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.planner.logical.LogicalPlan; -import org.springframework.context.annotation.Configuration; -import org.springframework.test.context.ContextConfiguration; -import org.springframework.test.context.junit.jupiter.SpringExtension; - -@Configuration -@ExtendWith(SpringExtension.class) -@ContextConfiguration(classes = {AnalyzerTestBase.class}) -class LogicalPlanOptimizerTest extends AnalyzerTestBase { +import org.opensearch.sql.planner.physical.PhysicalPlan; 
+import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; + +@ExtendWith(MockitoExtension.class) +class LogicalPlanOptimizerTest { + + @Mock + private Table table; + + @Spy + private TableScanBuilder tableScanBuilder; + + @BeforeEach + void setUp() { + when(table.createScanBuilder()).thenReturn(tableScanBuilder); + } + /** * Filter - Filter --> Filter. */ @@ -37,7 +61,7 @@ class LogicalPlanOptimizerTest extends AnalyzerTestBase { void filter_merge_filter() { assertEquals( filter( - relation("schema", table), + tableScanBuilder, DSL.and(DSL.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(2))), DSL.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))) ), @@ -61,7 +85,7 @@ void push_filter_under_sort() { assertEquals( sort( filter( - relation("schema", table), + tableScanBuilder, DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) ), Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) @@ -86,7 +110,7 @@ void multiple_filter_should_eventually_be_merged() { assertEquals( sort( filter( - relation("schema", table), + tableScanBuilder, DSL.and(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), DSL.less(DSL.ref("longV", INTEGER), DSL.literal(longValue(1L)))) ), @@ -107,6 +131,145 @@ void multiple_filter_should_eventually_be_merged() { ); } + @Test + void default_table_scan_builder_should_not_push_down_anything() { + LogicalPlan[] plans = { + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))), + sort( + relation("schema", table), + Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), + limit( + relation("schema", 
table), + 1, 1) + }; + + for (LogicalPlan plan : plans) { + assertEquals(plan, optimize(plan)); + } + } + + @Test + void table_scan_builder_support_project_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownProject(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void table_scan_builder_support_filter_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownFilter(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) + ) + ); + } + + @Test + void table_scan_builder_support_aggregation_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownAggregation(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))) + ) + ); + } + + @Test + void table_scan_builder_support_sort_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownSort(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + sort( + relation("schema", table), + Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void table_scan_builder_support_limit_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownLimit(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + limit( + relation("schema", table), + 1, 1) + ) + ); + } + + @Test + void table_scan_builder_support_highlight_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownHighlight(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + highlight( + relation("schema", table), + DSL.literal("*"), + Collections.emptyMap()) + ) 
+ ); + } + + @Test + void table_not_support_scan_builder_should_not_be_impact() { + Mockito.reset(table, tableScanBuilder); + Table table = new Table() { + @Override + public Map getFieldTypes() { + return null; + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + return null; + } + }; + + assertEquals( + relation("schema", table), + optimize(relation("schema", table)) + ); + } + private LogicalPlan optimize(LogicalPlan plan) { final LogicalPlanOptimizer optimizer = LogicalPlanOptimizer.create(); final LogicalPlan optimize = optimizer.optimize(plan); diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java index ad7c7c50dc..61d192362a 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/pattern/PatternsTest.java @@ -13,7 +13,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; +import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.planner.logical.LogicalFilter; import org.opensearch.sql.planner.logical.LogicalPlan; @ExtendWith(MockitoExtension.class) @@ -26,5 +28,12 @@ class PatternsTest { void source_is_empty() { when(plan.getChild()).thenReturn(Collections.emptyList()); assertFalse(Patterns.source().getFunction().apply(plan).isPresent()); + assertFalse(Patterns.source(null).getProperty().getFunction().apply(plan).isPresent()); + } + + @Test + void table_is_empty() { + plan = Mockito.mock(LogicalFilter.class); + assertFalse(Patterns.table().getFunction().apply(plan).isPresent()); } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java deleted file 
mode 100644 index 84bfb47a08..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexAgg.java +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import com.google.common.collect.ImmutableList; -import java.util.List; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.NamedExpression; -import org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; - -/** - * Logical Index Scan Aggregation Operation. - */ -@Getter -@ToString -@EqualsAndHashCode(callSuper = false) -public class OpenSearchLogicalIndexAgg extends LogicalPlan { - - private final String relationName; - - /** - * Filter Condition. - */ - @Setter - private Expression filter; - - /** - * Aggregation List. - */ - @Setter - private List aggregatorList; - - /** - * Group List. - */ - @Setter - private List groupByList; - - /** - * Sort List. - */ - @Setter - private List> sortList; - - /** - * ElasticsearchLogicalIndexAgg Constructor. 
- */ - @Builder - public OpenSearchLogicalIndexAgg( - String relationName, - Expression filter, - List aggregatorList, - List groupByList, - List> sortList) { - super(ImmutableList.of()); - this.relationName = relationName; - this.filter = filter; - this.aggregatorList = aggregatorList; - this.groupByList = groupByList; - this.sortList = sortList; - } - - @Override - public R accept(LogicalPlanNodeVisitor visitor, C context) { - return visitor.visitNode(this, context); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java deleted file mode 100644 index d182b5f84d..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScan.java +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import com.google.common.collect.ImmutableList; -import java.util.List; -import java.util.Set; -import lombok.Builder; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.Setter; -import lombok.ToString; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalPlanNodeVisitor; - -/** - * OpenSearch Logical Index Scan Operation. - */ -@Getter -@ToString -@EqualsAndHashCode(callSuper = false) -public class OpenSearchLogicalIndexScan extends LogicalPlan { - - /** - * Relation Name. - */ - private final String relationName; - - /** - * Filter Condition. - */ - @Setter - private Expression filter; - - /** - * Projection List. 
- */ - @Setter - private Set projectList; - - /** - * Sort List. - */ - @Setter - private List> sortList; - - @Setter - private Integer offset; - - @Setter - private Integer limit; - - /** - * ElasticsearchLogicalIndexScan Constructor. - */ - @Builder - public OpenSearchLogicalIndexScan( - String relationName, - Expression filter, - Set projectList, - List> sortList, - Integer limit, Integer offset) { - super(ImmutableList.of()); - this.relationName = relationName; - this.filter = filter; - this.projectList = projectList; - this.sortList = sortList; - this.limit = limit; - this.offset = offset; - } - - @Override - public R accept(LogicalPlanNodeVisitor visitor, C context) { - return visitor.visitNode(this, context); - } - - public boolean hasLimit() { - return limit != null; - } - - /** - * Test has projects or not. - * - * @return true for has projects, otherwise false. - */ - public boolean hasProjects() { - return projectList != null && !projectList.isEmpty(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java deleted file mode 100644 index 77cb6b13bd..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalPlanOptimizerFactory.java +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import java.util.Arrays; -import lombok.experimental.UtilityClass; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeAggAndIndexScan; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeAggAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeFilterAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeLimitAndIndexScan; -import 
org.opensearch.sql.opensearch.planner.logical.rule.MergeLimitAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeSortAndIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeSortAndIndexScan; -import org.opensearch.sql.opensearch.planner.logical.rule.MergeSortAndRelation; -import org.opensearch.sql.opensearch.planner.logical.rule.PushProjectAndIndexScan; -import org.opensearch.sql.opensearch.planner.logical.rule.PushProjectAndRelation; -import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; - -/** - * OpenSearch storage specified logical plan optimizer. - */ -@UtilityClass -public class OpenSearchLogicalPlanOptimizerFactory { - - /** - * Create OpenSearch storage specified logical plan optimizer. - */ - public static LogicalPlanOptimizer create() { - return new LogicalPlanOptimizer(Arrays.asList( - new MergeFilterAndRelation(), - new MergeAggAndIndexScan(), - new MergeAggAndRelation(), - new MergeSortAndRelation(), - new MergeSortAndIndexScan(), - new MergeSortAndIndexAgg(), - new MergeSortAndIndexScan(), - new MergeLimitAndRelation(), - new MergeLimitAndIndexScan(), - new PushProjectAndRelation(), - new PushProjectAndIndexScan() - )); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java deleted file mode 100644 index 3d4d999d12..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndIndexScan.java +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import 
com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalAggregation; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Aggregation -- Relation to IndexScanAggregation. - */ -public class MergeAggAndIndexScan implements Rule { - - private final Capture capture; - - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor of MergeAggAndIndexScan. - */ - public MergeAggAndIndexScan() { - this.capture = Capture.newCapture(); - this.pattern = typeOf(LogicalAggregation.class) - .with(source().matching(typeOf(OpenSearchLogicalIndexScan.class) - .matching(indexScan -> !indexScan.hasLimit()) - .capturedAs(capture))); - } - - @Override - public LogicalPlan apply(LogicalAggregation aggregation, - Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(capture); - return OpenSearchLogicalIndexAgg - .builder() - .relationName(indexScan.getRelationName()) - .filter(indexScan.getFilter()) - .aggregatorList(aggregation.getAggregatorList()) - .groupByList(aggregation.getGroupByList()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java deleted file mode 100644 index 2e79e7c51a..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeAggAndRelation.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - 
-import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.planner.logical.LogicalAggregation; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Aggregation -- Relation to IndexScanAggregation. - */ -public class MergeAggAndRelation implements Rule { - - private final Capture relationCapture; - - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor of MergeAggAndRelation. - */ - public MergeAggAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalAggregation.class) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public LogicalPlan apply(LogicalAggregation aggregation, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexAgg - .builder() - .relationName(relation.getRelationName()) - .aggregatorList(aggregation.getAggregatorList()) - .groupByList(aggregation.getGroupByList()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java deleted file mode 100644 index 19143c390e..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeFilterAndRelation.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * 
SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalFilter; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Filter -- Relation to LogicalIndexScan. - */ -public class MergeFilterAndRelation implements Rule { - - private final Capture relationCapture; - private final Pattern pattern; - - /** - * Constructor of MergeFilterAndRelation. - */ - public MergeFilterAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalFilter.class) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalFilter filter, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexScan - .builder() - .relationName(relation.getRelationName()) - .filter(filter.getCondition()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java deleted file mode 100644 index 9d880bb4dc..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndIndexScan.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * 
SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalLimit; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.Rule; - -@Getter -public class MergeLimitAndIndexScan implements Rule { - - private final Capture indexScanCapture; - - @Accessors(fluent = true) - private final Pattern pattern; - - /** - * Constructor of MergeLimitAndIndexScan. - */ - public MergeLimitAndIndexScan() { - this.indexScanCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalLimit.class) - .with(source() - .matching(typeOf(OpenSearchLogicalIndexScan.class).capturedAs(indexScanCapture))); - } - - @Override - public LogicalPlan apply(LogicalLimit plan, Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(indexScanCapture); - OpenSearchLogicalIndexScan.OpenSearchLogicalIndexScanBuilder builder = - OpenSearchLogicalIndexScan.builder(); - builder.relationName(indexScan.getRelationName()) - .filter(indexScan.getFilter()) - .offset(plan.getOffset()) - .limit(plan.getLimit()); - if (indexScan.getSortList() != null) { - builder.sortList(indexScan.getSortList()); - } - return builder.build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java deleted file mode 100644 index 8a170aaa4a..0000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeLimitAndRelation.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalLimit; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -@Getter -public class MergeLimitAndRelation implements Rule { - - private final Capture relationCapture; - - @Accessors(fluent = true) - private final Pattern pattern; - - /** - * Constructor of MergeLimitAndRelation. 
- */ - public MergeLimitAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalLimit.class) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public LogicalPlan apply(LogicalLimit plan, Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexScan.builder() - .relationName(relation.getRelationName()) - .offset(plan.getOffset()) - .limit(plan.getLimit()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java deleted file mode 100644 index 57dac4dcf1..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexAgg.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; -import java.util.stream.Collectors; -import lombok.Getter; -import lombok.experimental.Accessors; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.planner.logical.LogicalPlan; -import 
org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Sort -- IndexScanAggregation to IndexScanAggregation. - */ -public class MergeSortAndIndexAgg implements Rule { - - private final Capture indexAggCapture; - - @Accessors(fluent = true) - @Getter - private final Pattern pattern; - - /** - * Constructor of MergeAggAndIndexScan. - */ - public MergeSortAndIndexAgg() { - this.indexAggCapture = Capture.newCapture(); - final AtomicReference sortRef = new AtomicReference<>(); - - this.pattern = typeOf(LogicalSort.class) - .matching(OptimizationRuleUtils::sortByFieldsOnly) - .matching(sort -> { - sortRef.set(sort); - return true; - }) - .with(source().matching(typeOf(OpenSearchLogicalIndexAgg.class) - .matching(indexAgg -> !hasAggregatorInSortBy(sortRef.get(), indexAgg)) - .capturedAs(indexAggCapture))); - } - - @Override - public LogicalPlan apply(LogicalSort sort, - Captures captures) { - OpenSearchLogicalIndexAgg indexAgg = captures.get(indexAggCapture); - return OpenSearchLogicalIndexAgg.builder() - .relationName(indexAgg.getRelationName()) - .filter(indexAgg.getFilter()) - .groupByList(indexAgg.getGroupByList()) - .aggregatorList(indexAgg.getAggregatorList()) - .sortList(sort.getSortList()) - .build(); - } - - private boolean hasAggregatorInSortBy(LogicalSort sort, OpenSearchLogicalIndexAgg agg) { - final Set aggregatorNames = - agg.getAggregatorList().stream().map(NamedAggregator::getName).collect(Collectors.toSet()); - for (Pair sortPair : sort.getSortList()) { - if (aggregatorNames.contains(((ReferenceExpression) sortPair.getRight()).getAttr())) { - return true; - } - } - return false; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java deleted file mode 100644 index 337f09308c..0000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndIndexScan.java +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.List; -import java.util.stream.Collectors; -import java.util.stream.Stream; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Sort with IndexScan only when Sort by fields. - */ -public class MergeSortAndIndexScan implements Rule { - - private final Capture indexScanCapture; - private final Pattern pattern; - - /** - * Constructor of MergeSortAndRelation. 
- */ - public MergeSortAndIndexScan() { - this.indexScanCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalSort.class).matching(OptimizationRuleUtils::sortByFieldsOnly) - .with(source() - .matching(typeOf(OpenSearchLogicalIndexScan.class).capturedAs(indexScanCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalSort sort, - Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(indexScanCapture); - - return OpenSearchLogicalIndexScan - .builder() - .relationName(indexScan.getRelationName()) - .filter(indexScan.getFilter()) - .sortList(mergeSortList(indexScan.getSortList(), sort.getSortList())) - .build(); - } - - private List> mergeSortList(List> l1, List> l2) { - if (null == l1) { - return l2; - } else { - return Stream.concat(l1.stream(), l2.stream()).collect(Collectors.toList()); - } - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java deleted file mode 100644 index 3ba3c7f645..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/MergeSortAndRelation.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; -import 
org.opensearch.sql.planner.logical.LogicalSort; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Merge Sort with Relation only when Sort by fields. - */ -public class MergeSortAndRelation implements Rule { - - private final Capture relationCapture; - private final Pattern pattern; - - /** - * Constructor of MergeSortAndRelation. - */ - public MergeSortAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalSort.class).matching(OptimizationRuleUtils::sortByFieldsOnly) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalSort sort, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return OpenSearchLogicalIndexScan - .builder() - .relationName(relation.getRelationName()) - .sortList(sort.getSortList()) - .build(); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java deleted file mode 100644 index aa1ffa9e4c..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/OptimizationRuleUtils.java +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Set; -import lombok.experimental.UtilityClass; -import org.opensearch.sql.expression.ExpressionNodeVisitor; -import org.opensearch.sql.expression.NamedExpression; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.planner.logical.LogicalSort; - -@UtilityClass -public class OptimizationRuleUtils { - - /** - * Does the sort list 
only contain {@link ReferenceExpression}. - * - * @param logicalSort LogicalSort. - * @return true only contain ReferenceExpression, otherwise false. - */ - public static boolean sortByFieldsOnly(LogicalSort logicalSort) { - return logicalSort.getSortList().stream() - .map(sort -> sort.getRight() instanceof ReferenceExpression) - .reduce(true, Boolean::logicalAnd); - } - - /** - * Find reference expression from expression. - * @param expressions a list of expression. - * - * @return a list of ReferenceExpression - */ - public static Set findReferenceExpressions( - List expressions) { - Set projectList = new HashSet<>(); - for (NamedExpression namedExpression : expressions) { - projectList.addAll(findReferenceExpression(namedExpression)); - } - return projectList; - } - - /** - * Find reference expression from expression. - * @param expression expression. - * - * @return a list of ReferenceExpression - */ - public static List findReferenceExpression( - NamedExpression expression) { - List results = new ArrayList<>(); - expression.accept(new ExpressionNodeVisitor() { - @Override - public Object visitReference(ReferenceExpression node, Object context) { - return results.add(node); - } - }, null); - return results; - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java deleted file mode 100644 index 43714282fb..0000000000 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndIndexScan.java +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.opensearch.planner.logical.rule.OptimizationRuleUtils.findReferenceExpressions; -import 
static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.Set; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalProject; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Push Project list into ElasticsearchLogicalIndexScan. - */ -public class PushProjectAndIndexScan implements Rule { - - private final Capture indexScanCapture; - - private final Pattern pattern; - - private Set pushDownProjects; - - /** - * Constructor of MergeProjectAndIndexScan. - */ - public PushProjectAndIndexScan() { - this.indexScanCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalProject.class).matching( - project -> { - pushDownProjects = findReferenceExpressions(project.getProjectList()); - return !pushDownProjects.isEmpty(); - }).with(source() - .matching(typeOf(OpenSearchLogicalIndexScan.class) - .matching(indexScan -> !indexScan.hasProjects()) - .capturedAs(indexScanCapture))); - - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalProject project, - Captures captures) { - OpenSearchLogicalIndexScan indexScan = captures.get(indexScanCapture); - indexScan.setProjectList(pushDownProjects); - return new LogicalProject(indexScan, project.getProjectList(), - project.getNamedParseExpressions()); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java deleted file mode 100644 index a29a1df466..0000000000 --- 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/planner/logical/rule/PushProjectAndRelation.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical.rule; - -import static com.facebook.presto.matching.Pattern.typeOf; -import static org.opensearch.sql.opensearch.planner.logical.rule.OptimizationRuleUtils.findReferenceExpressions; -import static org.opensearch.sql.planner.optimizer.pattern.Patterns.source; - -import com.facebook.presto.matching.Capture; -import com.facebook.presto.matching.Captures; -import com.facebook.presto.matching.Pattern; -import java.util.Set; -import org.opensearch.sql.expression.ReferenceExpression; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalProject; -import org.opensearch.sql.planner.logical.LogicalRelation; -import org.opensearch.sql.planner.optimizer.Rule; - -/** - * Push Project list into Relation. The transformed plan is Project - IndexScan - */ -public class PushProjectAndRelation implements Rule { - - private final Capture relationCapture; - - private final Pattern pattern; - - private Set pushDownProjects; - - /** - * Constructor of MergeProjectAndRelation. 
- */ - public PushProjectAndRelation() { - this.relationCapture = Capture.newCapture(); - this.pattern = typeOf(LogicalProject.class) - .matching(project -> { - pushDownProjects = findReferenceExpressions(project.getProjectList()); - return !pushDownProjects.isEmpty(); - }) - .with(source().matching(typeOf(LogicalRelation.class).capturedAs(relationCapture))); - } - - @Override - public Pattern pattern() { - return pattern; - } - - @Override - public LogicalPlan apply(LogicalProject project, - Captures captures) { - LogicalRelation relation = captures.get(relationCapture); - return new LogicalProject( - OpenSearchLogicalIndexScan - .builder() - .relationName(relation.getRelationName()) - .projectList(findReferenceExpressions(project.getProjectList())) - .build(), - project.getProjectList(), - project.getNamedParseExpressions() - ); - } -} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java index c26413c622..439a970a4f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilder.java @@ -9,6 +9,8 @@ import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; import static org.opensearch.search.sort.SortOrder.ASC; +import com.google.common.collect.Lists; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; @@ -24,7 +26,9 @@ import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.opensearch.search.sort.FieldSortBuilder; import org.opensearch.search.sort.SortBuilder; +import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.ast.expression.Literal; import 
org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.common.utils.StringUtils; @@ -158,6 +162,11 @@ public void pushDownAggregation( * @param sortBuilders sortBuilders. */ public void pushDownSort(List> sortBuilders) { + // TODO: Sort by _doc is added when filter push down. Remove both logic once doctest fixed. + if (isSortByDocOnly()) { + sourceBuilder.sorts().clear(); + } + for (SortBuilder sortBuilder : sortBuilders) { sourceBuilder.sort(sortBuilder); } @@ -220,4 +229,12 @@ public void pushTypeMapping(Map typeMapping) { private boolean isBoolFilterQuery(QueryBuilder current) { return (current instanceof BoolQueryBuilder); } + + private boolean isSortByDocOnly() { + List> sorts = sourceBuilder.sorts(); + if (sorts != null) { + return sorts.equals(Arrays.asList(SortBuilders.fieldSort(DOC_FIELD_NAME))); + } + return false; + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java index 00e8a5154c..7459300caa 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/CompositeAggregationParser.java @@ -18,12 +18,14 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import org.opensearch.search.aggregations.Aggregations; import org.opensearch.search.aggregations.bucket.composite.CompositeAggregation; /** * Composite Aggregation Parser which include composite aggregation and metric parsers. 
*/ +@EqualsAndHashCode public class CompositeAggregationParser implements OpenSearchAggregationResponseParser { private final MetricParserHelper metricsParser; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java index cfcba82c18..8358379be0 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/FilterParser.java @@ -15,6 +15,7 @@ import java.util.Map; import lombok.Builder; +import lombok.EqualsAndHashCode; import lombok.Getter; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.bucket.filter.Filter; @@ -25,6 +26,7 @@ * do nothing and return the result from metricsParser. */ @Builder +@EqualsAndHashCode public class FilterParser implements MetricParser { private final MetricParser metricsParser; diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java index 54b9305f49..d5c0141ad2 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/MetricParserHelper.java @@ -17,6 +17,7 @@ import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; import org.opensearch.search.aggregations.Aggregations; @@ -25,6 +26,7 @@ /** * Parse multiple metrics in one bucket. 
*/ +@EqualsAndHashCode @RequiredArgsConstructor public class MetricParserHelper { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java index 88d9604137..384e07ad8f 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SingleValueParser.java @@ -17,6 +17,7 @@ import java.util.Collections; import java.util.Map; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; @@ -25,6 +26,7 @@ /** * {@link NumericMetricsAggregation.SingleValue} metric parser. */ +@EqualsAndHashCode @RequiredArgsConstructor public class SingleValueParser implements MetricParser { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java index 5928b7efc9..c80b75de05 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/StatsParser.java @@ -18,6 +18,7 @@ import java.util.Collections; import java.util.Map; import java.util.function.Function; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; @@ -26,6 +27,7 @@ /** * {@link ExtendedStats} metric parser. 
*/ +@EqualsAndHashCode @RequiredArgsConstructor public class StatsParser implements MetricParser { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java index 4a3a346a84..a98e1b4ce3 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/TopHitsParser.java @@ -10,6 +10,7 @@ import java.util.Collections; import java.util.Map; import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.search.aggregations.Aggregation; @@ -18,6 +19,7 @@ /** * {@link TopHits} metric parser. */ +@EqualsAndHashCode @RequiredArgsConstructor public class TopHitsParser implements MetricParser { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java index 26082abed1..c694769b89 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/OpenSearchIndex.java @@ -8,41 +8,27 @@ import com.google.common.annotations.VisibleForTesting; import java.util.HashMap; -import java.util.List; import java.util.Map; -import java.util.stream.Collectors; import lombok.RequiredArgsConstructor; -import org.apache.commons.lang3.tuple.Pair; -import org.opensearch.index.query.QueryBuilder; -import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.sql.common.setting.Settings; -import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.opensearch.client.OpenSearchClient; import org.opensearch.sql.opensearch.data.type.OpenSearchDataType; import 
org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalPlanOptimizerFactory; import org.opensearch.sql.opensearch.planner.physical.ADOperator; import org.opensearch.sql.opensearch.planner.physical.MLCommonsOperator; import org.opensearch.sql.opensearch.planner.physical.MLOperator; import org.opensearch.sql.opensearch.request.OpenSearchRequest; import org.opensearch.sql.opensearch.request.system.OpenSearchDescribeIndexRequest; -import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; -import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; -import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; -import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; -import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.opensearch.storage.scan.OpenSearchIndexScanBuilder; import org.opensearch.sql.planner.DefaultImplementor; import org.opensearch.sql.planner.logical.LogicalAD; -import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.logical.LogicalRelation; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.storage.Table; +import org.opensearch.sql.storage.read.TableScanBuilder; /** OpenSearch table (index) implementation. 
*/ public class OpenSearchIndex implements Table { @@ -122,98 +108,30 @@ public Integer getMaxResultWindow() { */ @Override public PhysicalPlan implement(LogicalPlan plan) { - OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, - getMaxResultWindow(), new OpenSearchExprValueFactory(getFieldTypes())); - - /* - * Visit logical plan with index scan as context so logical operators visited, such as - * aggregation, filter, will accumulate (push down) OpenSearch query and aggregation DSL on - * index scan. - */ - return plan.accept(new OpenSearchDefaultImplementor(indexScan, client), indexScan); + // TODO: Leave it here to avoid impact Prometheus and AD operators. Need to move to Planner. + return plan.accept(new OpenSearchDefaultImplementor(client), null); } @Override public LogicalPlan optimize(LogicalPlan plan) { - return OpenSearchLogicalPlanOptimizerFactory.create().optimize(plan); + // No-op because optimization already done in Planner + return plan; + } + + @Override + public TableScanBuilder createScanBuilder() { + OpenSearchIndexScan indexScan = new OpenSearchIndexScan(client, settings, indexName, + getMaxResultWindow(), new OpenSearchExprValueFactory(getFieldTypes())); + return new OpenSearchIndexScanBuilder(indexScan); } @VisibleForTesting @RequiredArgsConstructor public static class OpenSearchDefaultImplementor extends DefaultImplementor { - private final OpenSearchIndexScan indexScan; private final OpenSearchClient client; - @Override - public PhysicalPlan visitNode(LogicalPlan plan, OpenSearchIndexScan context) { - if (plan instanceof OpenSearchLogicalIndexScan) { - return visitIndexScan((OpenSearchLogicalIndexScan) plan, context); - } else if (plan instanceof OpenSearchLogicalIndexAgg) { - return visitIndexAggregation((OpenSearchLogicalIndexAgg) plan, context); - } else { - throw new IllegalStateException(StringUtils.format("unexpected plan node type %s", - plan.getClass())); - } - } - - /** - * Implement 
ElasticsearchLogicalIndexScan. - */ - public PhysicalPlan visitIndexScan(OpenSearchLogicalIndexScan node, - OpenSearchIndexScan context) { - if (null != node.getSortList()) { - final SortQueryBuilder builder = new SortQueryBuilder(); - context.getRequestBuilder().pushDownSort(node.getSortList().stream() - .map(sort -> builder.build(sort.getValue(), sort.getKey())) - .collect(Collectors.toList())); - } - - if (null != node.getFilter()) { - FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer()); - QueryBuilder query = queryBuilder.build(node.getFilter()); - context.getRequestBuilder().pushDown(query); - } - - if (node.getLimit() != null) { - context.getRequestBuilder().pushDownLimit(node.getLimit(), node.getOffset()); - } - - if (node.hasProjects()) { - context.getRequestBuilder().pushDownProjects(node.getProjectList()); - } - return indexScan; - } - - /** - * Implement ElasticsearchLogicalIndexAgg. - */ - public PhysicalPlan visitIndexAggregation(OpenSearchLogicalIndexAgg node, - OpenSearchIndexScan context) { - if (node.getFilter() != null) { - FilterQueryBuilder queryBuilder = new FilterQueryBuilder( - new DefaultExpressionSerializer()); - QueryBuilder query = queryBuilder.build(node.getFilter()); - context.getRequestBuilder().pushDown(query); - } - AggregationQueryBuilder builder = - new AggregationQueryBuilder(new DefaultExpressionSerializer()); - Pair, OpenSearchAggregationResponseParser> aggregationBuilder = - builder.buildAggregationBuilder(node.getAggregatorList(), - node.getGroupByList(), node.getSortList()); - context.getRequestBuilder().pushDownAggregation(aggregationBuilder); - context.getRequestBuilder().pushTypeMapping( - builder.buildTypeMapping(node.getAggregatorList(), - node.getGroupByList())); - return indexScan; - } - - @Override - public PhysicalPlan visitRelation(LogicalRelation node, OpenSearchIndexScan context) { - return indexScan; - } - @Override public PhysicalPlan visitMLCommons(LogicalMLCommons node, 
OpenSearchIndexScan context) { return new MLCommonsOperator(visitChild(node, context), node.getAlgorithm(), @@ -231,12 +149,5 @@ public PhysicalPlan visitML(LogicalML node, OpenSearchIndexScan context) { return new MLOperator(visitChild(node, context), node.getArguments(), client.getNodeClient()); } - - @Override - public PhysicalPlan visitHighlight(LogicalHighlight node, OpenSearchIndexScan context) { - context.getRequestBuilder().pushDownHighlight( - StringUtils.unquoteText(node.getHighlightField().toString()), node.getArguments()); - return visitChild(node, context); - } } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java new file mode 100644 index 0000000000..e52fc566cd --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanAggregationBuilder.java @@ -0,0 +1,92 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; +import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import 
org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Index scan builder for aggregate query used by {@link OpenSearchIndexScanBuilder} internally. + */ +class OpenSearchIndexScanAggregationBuilder extends TableScanBuilder { + + /** OpenSearch index scan to be optimized. */ + private final OpenSearchIndexScan indexScan; + + /** Aggregators pushed down. */ + private List aggregatorList; + + /** Grouping items pushed down. */ + private List groupByList; + + /** Sorting items pushed down. */ + private List> sortList; + + /** + * Initialize with given index scan and perform push-down optimization later. + * + * @param indexScan index scan not fully optimized yet + */ + OpenSearchIndexScanAggregationBuilder(OpenSearchIndexScan indexScan) { + this.indexScan = indexScan; + } + + @Override + public TableScanOperator build() { + AggregationQueryBuilder builder = + new AggregationQueryBuilder(new DefaultExpressionSerializer()); + Pair, OpenSearchAggregationResponseParser> aggregationBuilder = + builder.buildAggregationBuilder(aggregatorList, groupByList, sortList); + indexScan.getRequestBuilder().pushDownAggregation(aggregationBuilder); + indexScan.getRequestBuilder().pushTypeMapping( + builder.buildTypeMapping(aggregatorList, groupByList)); + return indexScan; + } + + @Override + public boolean pushDownAggregation(LogicalAggregation aggregation) { + aggregatorList = aggregation.getAggregatorList(); + groupByList = aggregation.getGroupByList(); + return true; + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + if (hasAggregatorInSortBy(sort)) { + return false; + } + + sortList = sort.getSortList(); + return true; + } + + private boolean hasAggregatorInSortBy(LogicalSort sort) { + final Set aggregatorNames = + 
aggregatorList.stream().map(NamedAggregator::getName).collect(Collectors.toSet()); + for (Pair sortPair : sort.getSortList()) { + if (aggregatorNames.contains(((ReferenceExpression) sortPair.getRight()).getAttr())) { + return true; + } + } + return false; + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java new file mode 100644 index 0000000000..d7483cfcf0 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanBuilder.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import com.google.common.annotations.VisibleForTesting; +import lombok.EqualsAndHashCode; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.planner.logical.LogicalAggregation; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Table scan builder that builds table scan operator for OpenSearch. The actual work is performed + * by delegated builder internally. This is to avoid conditional check of different push down logic + * for non-aggregate and aggregate query everywhere. + */ +public class OpenSearchIndexScanBuilder extends TableScanBuilder { + + /** + * Delegated index scan builder for non-aggregate or aggregate query. 
+ */ + @EqualsAndHashCode.Include + private TableScanBuilder delegate; + + /** Is limit operator pushed down. */ + private boolean isLimitPushedDown = false; + + @VisibleForTesting + OpenSearchIndexScanBuilder(TableScanBuilder delegate) { + this.delegate = delegate; + } + + /** + * Initialize with given index scan. + * + * @param indexScan index scan to optimize + */ + public OpenSearchIndexScanBuilder(OpenSearchIndexScan indexScan) { + this.delegate = new OpenSearchIndexScanQueryBuilder(indexScan); + } + + @Override + public TableScanOperator build() { + return delegate.build(); + } + + @Override + public boolean pushDownFilter(LogicalFilter filter) { + return delegate.pushDownFilter(filter); + } + + @Override + public boolean pushDownAggregation(LogicalAggregation aggregation) { + if (isLimitPushedDown) { + return false; + } + + // Switch to builder for aggregate query which has different push down logic + // for later filter, sort and limit operator. + delegate = new OpenSearchIndexScanAggregationBuilder( + (OpenSearchIndexScan) delegate.build()); + + return delegate.pushDownAggregation(aggregation); + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + if (!sortByFieldsOnly(sort)) { + return false; + } + return delegate.pushDownSort(sort); + } + + @Override + public boolean pushDownLimit(LogicalLimit limit) { + // Assume limit push down happening on OpenSearchIndexScanQueryBuilder + isLimitPushedDown = true; + return delegate.pushDownLimit(limit); + } + + @Override + public boolean pushDownProject(LogicalProject project) { + return delegate.pushDownProject(project); + } + + @Override + public boolean pushDownHighlight(LogicalHighlight highlight) { + return delegate.pushDownHighlight(highlight); + } + + private boolean sortByFieldsOnly(LogicalSort sort) { + return sort.getSortList().stream() + .map(sortItem -> sortItem.getRight() instanceof ReferenceExpression) + .reduce(true, Boolean::logicalAnd); + } +} diff --git 
a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java new file mode 100644 index 0000000000..7190d58000 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanQueryBuilder.java @@ -0,0 +1,133 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.scan; + +import com.google.common.annotations.VisibleForTesting; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ExpressionNodeVisitor; +import org.opensearch.sql.expression.NamedExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; +import org.opensearch.sql.opensearch.storage.script.sort.SortQueryBuilder; +import org.opensearch.sql.opensearch.storage.serialization.DefaultExpressionSerializer; +import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalHighlight; +import org.opensearch.sql.planner.logical.LogicalLimit; +import org.opensearch.sql.planner.logical.LogicalProject; +import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.storage.TableScanOperator; +import org.opensearch.sql.storage.read.TableScanBuilder; + +/** + * Index scan builder for simple non-aggregate query used by + * {@link 
OpenSearchIndexScanBuilder} internally. + */ +@VisibleForTesting +class OpenSearchIndexScanQueryBuilder extends TableScanBuilder { + + /** OpenSearch index scan to be optimized. */ + @EqualsAndHashCode.Include + private final OpenSearchIndexScan indexScan; + + /** + * Initialize with given index scan and perform push-down optimization later. + * + * @param indexScan index scan not optimized yet + */ + OpenSearchIndexScanQueryBuilder(OpenSearchIndexScan indexScan) { + this.indexScan = indexScan; + } + + @Override + public TableScanOperator build() { + return indexScan; + } + + @Override + public boolean pushDownFilter(LogicalFilter filter) { + FilterQueryBuilder queryBuilder = new FilterQueryBuilder( + new DefaultExpressionSerializer()); + QueryBuilder query = queryBuilder.build(filter.getCondition()); + indexScan.getRequestBuilder().pushDown(query); + return true; + } + + @Override + public boolean pushDownSort(LogicalSort sort) { + List> sortList = sort.getSortList(); + final SortQueryBuilder builder = new SortQueryBuilder(); + indexScan.getRequestBuilder().pushDownSort(sortList.stream() + .map(sortItem -> builder.build(sortItem.getValue(), sortItem.getKey())) + .collect(Collectors.toList())); + return true; + } + + @Override + public boolean pushDownLimit(LogicalLimit limit) { + indexScan.getRequestBuilder().pushDownLimit(limit.getLimit(), limit.getOffset()); + return true; + } + + @Override + public boolean pushDownProject(LogicalProject project) { + indexScan.getRequestBuilder().pushDownProjects( + findReferenceExpressions(project.getProjectList())); + + // Return false intentionally to keep the original project operator + return false; + } + + @Override + public boolean pushDownHighlight(LogicalHighlight highlight) { + indexScan.getRequestBuilder().pushDownHighlight( + StringUtils.unquoteText(highlight.getHighlightField().toString()), + highlight.getArguments()); + return true; + } + + /** + * Find reference expression from expression. 
+ * @param expressions a list of expression. + * + * @return a list of ReferenceExpression + */ + public static Set findReferenceExpressions( + List expressions) { + Set projectList = new HashSet<>(); + for (NamedExpression namedExpression : expressions) { + projectList.addAll(findReferenceExpression(namedExpression)); + } + return projectList; + } + + /** + * Find reference expression from expression. + * @param expression expression. + * + * @return a list of ReferenceExpression + */ + public static List findReferenceExpression(NamedExpression expression) { + List results = new ArrayList<>(); + expression.accept(new ExpressionNodeVisitor<>() { + @Override + public Object visitReference(ReferenceExpression node, Object context) { + return results.add(node); + } + }, null); + return results; + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java deleted file mode 100644 index 31ad2b2ee3..0000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicOptimizerTest.java +++ /dev/null @@ -1,576 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import static org.junit.jupiter.api.Assertions.assertEquals; -import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; -import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.data.type.ExprCoreType.LONG; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; -import static org.opensearch.sql.opensearch.utils.Utils.indexScan; -import static org.opensearch.sql.opensearch.utils.Utils.indexScanAgg; -import static org.opensearch.sql.opensearch.utils.Utils.noProjects; -import static 
org.opensearch.sql.opensearch.utils.Utils.projects; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import org.apache.commons.lang3.tuple.Pair; -import org.junit.jupiter.api.Disabled; -import org.junit.jupiter.api.Test; -import org.junit.jupiter.api.extension.ExtendWith; -import org.mockito.Mock; -import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.DSL; -import org.opensearch.sql.opensearch.utils.Utils; -import org.opensearch.sql.planner.logical.LogicalPlan; -import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; -import org.opensearch.sql.storage.Table; - -@ExtendWith(MockitoExtension.class) -class OpenSearchLogicOptimizerTest { - - @Mock - private Table table; - - /** - * SELECT intV as i FROM schema WHERE intV = 1. - */ - @Test - void project_filter_merge_with_relation() { - assertEquals( - project( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - ImmutableSet.of(DSL.ref("intV", INTEGER))), - DSL.named("i", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), - DSL.named("i", DSL.ref("intV", INTEGER))) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY string_value. 
- */ - @Test - void aggregation_merge_relation() { - assertEquals( - project( - indexScanAgg("schema", ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - aggregation( - relation("schema", table), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema WHERE intV = 1 GROUP BY string_value. - */ - @Test - void aggregation_merge_filter_relation() { - assertEquals( - project( - indexScanAgg("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - aggregation( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - @Disabled("This test should be enabled once https://github.com/opensearch-project/sql/issues/912 is fixed") - @Test - void aggregation_cant_merge_indexScan_with_project() { - assertEquals( - aggregation( - OpenSearchLogicalIndexScan.builder().relationName("schema") - .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) - .projectList(ImmutableSet.of(DSL.ref("intV", INTEGER))) - .build(), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - 
ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - optimize( - aggregation( - OpenSearchLogicalIndexScan.builder().relationName("schema") - .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) - .projectList( - ImmutableSet.of(DSL.ref("intV", INTEGER))) - .build(), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG)))))) - ); - } - - /** - * Sort - Relation --> IndexScan. - */ - @Test - void sort_merge_with_relation() { - assertEquals( - indexScan("schema", Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), - optimize( - sort( - relation("schema", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - /** - * Sort - IndexScan --> IndexScan. - */ - @Test - void sort_merge_with_indexScan() { - assertEquals( - indexScan("schema", - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG))), - optimize( - sort( - indexScan("schema", Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER))), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ) - ) - ); - } - - /** - * Sort - Filter - Relation --> IndexScan. 
- */ - @Test - void sort_filter_merge_with_relation() { - assertEquals( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ), - optimize( - sort( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ) - ) - ); - } - - @Test - void sort_with_expression_cannot_merge_with_relation() { - assertEquals( - sort( - relation("schema", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) - ), - optimize( - sort( - relation("schema", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) - ) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV. - */ - @Test - void sort_merge_indexagg() { - assertEquals( - project( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING))), - ImmutableList - .of(Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("stringV", STRING)))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - sort( - aggregation( - relation("schema", table), - ImmutableList - .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("stringV", STRING)) - ), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV. 
- */ - @Test - void sort_merge_indexagg_nulls_last() { - assertEquals( - project( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING))), - ImmutableList - .of(Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("stringV", STRING)))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - sort( - aggregation( - relation("schema", table), - ImmutableList - .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("stringV", STRING)) - ), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))) - ) - ); - } - - - /** - * Can't Optimize the following query. - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY avg(intV). - */ - @Test - void sort_refer_to_aggregator_should_not_merge_with_indexAgg() { - assertEquals( - sort( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("AVG(intV)", INTEGER)) - ), - optimize( - sort( - indexScanAgg("schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of(Sort.SortOption.DEFAULT_DESC, DSL.ref("AVG(intV)", INTEGER)) - ) - ) - ); - } - - /** - * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV ASC NULL_LAST. 
- */ - @Test - void sort_with_customized_option_should_merge_with_indexAgg() { - assertEquals( - indexScanAgg( - "schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING))), - ImmutableList.of( - Pair.of( - new Sort.SortOption(Sort.SortOrder.ASC, Sort.NullOrder.NULL_LAST), - DSL.ref("stringV", STRING)))), - optimize( - sort( - indexScanAgg( - "schema", - ImmutableList.of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - Pair.of( - new Sort.SortOption(Sort.SortOrder.ASC, Sort.NullOrder.NULL_LAST), - DSL.ref("stringV", STRING))))); - } - - @Test - void limit_merge_with_relation() { - assertEquals( - project( - indexScan("schema", 1, 1, projects(DSL.ref("intV", INTEGER))), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - limit( - relation("schema", table), - 1, 1 - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - @Test - void limit_merge_with_index_scan() { - assertEquals( - project( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - 1, 1, - projects(DSL.ref("intV", INTEGER)) - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - limit( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) - ), 1, 1 - ), - DSL.named("intV", DSL.ref("intV", INTEGER))) - ) - ); - } - - @Test - void limit_merge_with_index_scan_sort() { - assertEquals( - project( - indexScan("schema", - DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))), - 1, 1, - Utils.sort(DSL.ref("longV", LONG), Sort.SortOption.DEFAULT_ASC), - projects(DSL.ref("intV", INTEGER)) - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - limit( - sort( - filter( - relation("schema", table), - DSL.equal(DSL.ref("intV", INTEGER), 
DSL.literal(integerValue(1))) - ), - Pair.of(Sort.SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) - ), 1, 1 - ), - DSL.named("intV", DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - @Test - void aggregation_cant_merge_index_scan_with_limit() { - assertEquals( - project( - aggregation( - indexScan("schema", 10, 0, noProjects()), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), - optimize( - project( - aggregation( - indexScan("schema", 10, 0, noProjects()), - ImmutableList - .of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)))), - ImmutableList.of(DSL.named("longV", - DSL.abs(DSL.ref("longV", LONG))))), - DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))))); - } - - @Test - void push_down_projectList_to_relation() { - assertEquals( - project( - indexScan("schema", projects(DSL.ref("intV", INTEGER))), - DSL.named("i", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - relation("schema", table), - DSL.named("i", DSL.ref("intV", INTEGER))) - ) - ); - } - - /** - * Project(intV, abs(intV)) -> Relation. - * -- will be optimized as - * Project(intV, abs(intV)) -> Relation(project=intV). - */ - @Test - void push_down_should_handle_duplication() { - assertEquals( - project( - indexScan("schema", projects(DSL.ref("intV", INTEGER))), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("absi", DSL.abs(DSL.ref("intV", INTEGER))) - ), - optimize( - project( - relation("schema", table), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("absi", DSL.abs(DSL.ref("intV", INTEGER)))) - ) - ); - } - - /** - * Project(ListA) -> Project(ListB) -> Relation. - * -- will be optimized as - * Project(ListA) -> Project(ListB) -> Relation(project=ListB). 
- */ - @Test - void only_one_project_should_be_push() { - assertEquals( - project( - project( - indexScan("schema", - projects(DSL.ref("intV", INTEGER), DSL.ref("stringV", STRING)) - ), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("s", DSL.ref("stringV", STRING)) - ), - DSL.named("i", DSL.ref("intV", INTEGER)) - ), - optimize( - project( - project( - relation("schema", table), - DSL.named("i", DSL.ref("intV", INTEGER)), - DSL.named("s", DSL.ref("stringV", STRING)) - ), - DSL.named("i", DSL.ref("intV", INTEGER)) - ) - ) - ); - } - - @Test - void project_literal_no_push() { - assertEquals( - project( - relation("schema", table), - DSL.named("i", DSL.literal("str")) - ), - optimize( - project( - relation("schema", table), - DSL.named("i", DSL.literal("str")) - ) - ) - ); - } - - /** - * SELECT AVG(intV) FILTER(WHERE intV > 1) FROM schema GROUP BY stringV. - */ - @Test - void filter_aggregation_merge_relation() { - assertEquals( - project( - indexScanAgg("schema", ImmutableList.of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))), - optimize( - project( - aggregation( - relation("schema", table), - ImmutableList.of(DSL.named("AVG(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))) - ) - ); - } - - /** - * SELECT AVG(intV) FILTER(WHERE intV > 1) FROM schema WHERE longV < 1 GROUP BY stringV. 
- */ - @Test - void filter_aggregation_merge_filter_relation() { - assertEquals( - project( - indexScanAgg("schema", - DSL.less(DSL.ref("longV", LONG), DSL.literal(1)), - ImmutableList.of(DSL.named("avg(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))), - optimize( - project( - aggregation( - filter( - relation("schema", table), - DSL.less(DSL.ref("longV", LONG), DSL.literal(1)) - ), - ImmutableList.of(DSL.named("avg(intV)", - DSL.avg(DSL.ref("intV", INTEGER)) - .condition(DSL.greater(DSL.ref("intV", INTEGER), DSL.literal(1))))), - ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), - DSL.named("avg(intV) filter(where intV > 1)", DSL.ref("avg(intV)", DOUBLE))) - ) - ); - } - - private LogicalPlan optimize(LogicalPlan plan) { - final LogicalPlanOptimizer optimizer = OpenSearchLogicalPlanOptimizerFactory.create(); - final LogicalPlan optimize = optimizer.optimize(plan); - return optimize; - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java deleted file mode 100644 index 2e10f33787..0000000000 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/planner/logical/OpenSearchLogicalIndexScanTest.java +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - - -package org.opensearch.sql.opensearch.planner.logical; - -import static org.junit.jupiter.api.Assertions.assertFalse; - -import com.google.common.collect.ImmutableSet; -import org.junit.jupiter.api.Test; - -class OpenSearchLogicalIndexScanTest { - - @Test - void has_projects() { - assertFalse(OpenSearchLogicalIndexScan.builder() - 
.projectList(ImmutableSet.of()).build() - .hasProjects()); - - assertFalse(OpenSearchLogicalIndexScan.builder().build().hasProjects()); - } -} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 43b9353190..33376ece83 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -7,41 +7,70 @@ package org.opensearch.sql.opensearch.request; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.search.sort.FieldSortBuilder.DOC_FIELD_NAME; +import static org.opensearch.search.sort.SortOrder.ASC; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.common.unit.TimeValue; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.sort.FieldSortBuilder; +import org.opensearch.search.sort.ScoreSortBuilder; +import org.opensearch.search.sort.SortBuilders; import org.opensearch.sql.common.setting.Settings; +import 
org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; @ExtendWith(MockitoExtension.class) public class OpenSearchRequestBuilderTest { - public static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); + private static final TimeValue DEFAULT_QUERY_TIMEOUT = TimeValue.timeValueMinutes(1L); + private static final Integer DEFAULT_OFFSET = 0; + private static final Integer DEFAULT_LIMIT = 200; + private static final Integer MAX_RESULT_WINDOW = 500; + @Mock private Settings settings; @Mock - private OpenSearchExprValueFactory factory; + private OpenSearchExprValueFactory exprValueFactory; + + private OpenSearchRequestBuilder requestBuilder; @BeforeEach void setup() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); + + requestBuilder = new OpenSearchRequestBuilder( + "test", MAX_RESULT_WINDOW, settings, exprValueFactory); } @Test void buildQueryRequest() { - Integer maxResultWindow = 500; Integer limit = 200; Integer offset = 0; - OpenSearchRequestBuilder builder = - new OpenSearchRequestBuilder("test", maxResultWindow, settings, factory); - builder.pushDownLimit(limit, offset); + requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchQueryRequest( @@ -50,27 +79,145 @@ void buildQueryRequest() { .from(offset) .size(limit) .timeout(DEFAULT_QUERY_TIMEOUT), - factory), - builder.build()); + exprValueFactory), + requestBuilder.build()); } @Test void buildScrollRequestWithCorrectSize() { - Integer maxResultWindow = 500; Integer limit = 800; Integer offset = 10; - OpenSearchRequestBuilder builder = - new 
OpenSearchRequestBuilder("test", maxResultWindow, settings, factory); - builder.pushDownLimit(limit, offset); + requestBuilder.pushDownLimit(limit, offset); assertEquals( new OpenSearchScrollRequest( new OpenSearchRequest.IndexName("test"), new SearchSourceBuilder() .from(offset) - .size(maxResultWindow - offset) + .size(MAX_RESULT_WINDOW - offset) .timeout(DEFAULT_QUERY_TIMEOUT), - factory), - builder.build()); + exprValueFactory), + requestBuilder.build()); + } + + @Test + void testPushDownQuery() { + QueryBuilder query = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDown(query); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(query) + .sort(DOC_FIELD_NAME, ASC), + requestBuilder.getSourceBuilder() + ); + } + + @Test + void testPushDownAggregation() { + AggregationBuilder aggBuilder = AggregationBuilders.composite( + "composite_buckets", + Collections.singletonList(new TermsValuesSourceBuilder("longA"))); + OpenSearchAggregationResponseParser responseParser = + new CompositeAggregationParser( + new SingleValueParser("AVG(intA)")); + requestBuilder.pushDownAggregation(Pair.of(List.of(aggBuilder), responseParser)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(0) + .timeout(DEFAULT_QUERY_TIMEOUT) + .aggregation(aggBuilder), + requestBuilder.getSourceBuilder() + ); + verify(exprValueFactory).setParser(responseParser); + } + + @Test + void testPushDownQueryAndSort() { + QueryBuilder query = QueryBuilders.termQuery("intA", 1); + requestBuilder.pushDown(query); + + FieldSortBuilder sortBuilder = SortBuilders.fieldSort("intA"); + requestBuilder.pushDownSort(List.of(sortBuilder)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .query(query) + .sort(sortBuilder), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownSort() { + FieldSortBuilder 
sortBuilder = SortBuilders.fieldSort("intA"); + requestBuilder.pushDownSort(List.of(sortBuilder)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .sort(sortBuilder), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownNonFieldSort() { + ScoreSortBuilder sortBuilder = SortBuilders.scoreSort(); + requestBuilder.pushDownSort(List.of(sortBuilder)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .sort(sortBuilder), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownMultipleSort() { + requestBuilder.pushDownSort(List.of( + SortBuilders.fieldSort("intA"), + SortBuilders.fieldSort("intB"))); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .sort(SortBuilders.fieldSort("intA")) + .sort(SortBuilders.fieldSort("intB")), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushDownProject() { + Set references = Set.of(DSL.ref("intA", INTEGER)); + requestBuilder.pushDownProjects(references); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(DEFAULT_LIMIT) + .timeout(DEFAULT_QUERY_TIMEOUT) + .fetchSource(new String[]{"intA"}, new String[0]), + requestBuilder.getSourceBuilder()); + } + + @Test + void testPushTypeMapping() { + Map typeMapping = Map.of("intA", INTEGER); + requestBuilder.pushTypeMapping(typeMapping); + + verify(exprValueFactory).setTypeMapping(typeMapping); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java index a74c5fcbd4..d7e5955491 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java +++ 
b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchDefaultImplementorTest.java @@ -6,12 +6,7 @@ package org.opensearch.sql.opensearch.storage; -import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.Mockito.verify; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.ExtendWith; @@ -20,9 +15,7 @@ import org.mockito.Mockito; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.sql.opensearch.client.OpenSearchClient; -import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; import org.opensearch.sql.planner.logical.LogicalAD; -import org.opensearch.sql.planner.logical.LogicalHighlight; import org.opensearch.sql.planner.logical.LogicalML; import org.opensearch.sql.planner.logical.LogicalMLCommons; import org.opensearch.sql.planner.logical.LogicalPlan; @@ -31,41 +24,20 @@ @ExtendWith(MockitoExtension.class) public class OpenSearchDefaultImplementorTest { - @Mock - OpenSearchIndexScan indexScan; @Mock OpenSearchClient client; @Mock Table table; - /** - * For test coverage. 
- */ - @Test - public void visitInvalidTypeShouldThrowException() { - final OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - - final IllegalStateException exception = - assertThrows(IllegalStateException.class, - () -> implementor.visitNode(relation("index", table), - indexScan)); - ; - assertEquals( - "unexpected plan node type " - + "class org.opensearch.sql.planner.logical.LogicalRelation", - exception.getMessage()); - } - @Test public void visitMachineLearning() { LogicalMLCommons node = Mockito.mock(LogicalMLCommons.class, Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - assertNotNull(implementor.visitMLCommons(node, indexScan)); + new OpenSearchIndex.OpenSearchDefaultImplementor(client); + assertNotNull(implementor.visitMLCommons(node, null)); } @Test @@ -74,8 +46,8 @@ public void visitAD() { Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - assertNotNull(implementor.visitAD(node, indexScan)); + new OpenSearchIndex.OpenSearchDefaultImplementor(client); + assertNotNull(implementor.visitAD(node, null)); } @Test @@ -84,21 +56,7 @@ public void visitML() { Answers.RETURNS_DEEP_STUBS); Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - assertNotNull(implementor.visitML(node, indexScan)); - } - - @Test - public void visitHighlight() { - LogicalHighlight node = Mockito.mock(LogicalHighlight.class, - Answers.RETURNS_DEEP_STUBS); - 
Mockito.when(node.getChild().get(0)).thenReturn(Mockito.mock(LogicalPlan.class)); - OpenSearchRequestBuilder requestBuilder = Mockito.mock(OpenSearchRequestBuilder.class); - Mockito.when(indexScan.getRequestBuilder()).thenReturn(requestBuilder); - OpenSearchIndex.OpenSearchDefaultImplementor implementor = - new OpenSearchIndex.OpenSearchDefaultImplementor(indexScan, client); - - implementor.visitHighlight(node, indexScan); - verify(requestBuilder).pushDownHighlight(any(), any()); + new OpenSearchIndex.OpenSearchDefaultImplementor(client); + assertNotNull(implementor.visitML(node, null)); } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java index 9e375aa1b0..74c18f7c3d 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/OpenSearchIndexTest.java @@ -9,8 +9,6 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.allOf; -import static org.hamcrest.Matchers.arrayContaining; -import static org.hamcrest.Matchers.emptyArray; import static org.hamcrest.Matchers.hasEntry; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -23,14 +21,7 @@ import static org.opensearch.sql.expression.DSL.named; import static org.opensearch.sql.expression.DSL.ref; import static org.opensearch.sql.opensearch.data.type.OpenSearchDataType.OPENSEARCH_TEXT_KEYWORD; -import static org.opensearch.sql.opensearch.utils.Utils.indexScan; -import static org.opensearch.sql.opensearch.utils.Utils.indexScanAgg; -import static org.opensearch.sql.opensearch.utils.Utils.noProjects; -import static org.opensearch.sql.opensearch.utils.Utils.projects; -import static 
org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.eval; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; -import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.remove; @@ -49,13 +40,11 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; -import org.opensearch.search.fetch.subphase.FetchSourceContext; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.NamedExpression; import org.opensearch.sql.expression.ReferenceExpression; @@ -67,12 +56,7 @@ import org.opensearch.sql.opensearch.mapping.IndexMapping; import org.opensearch.sql.planner.logical.LogicalPlan; import org.opensearch.sql.planner.logical.LogicalPlanDSL; -import org.opensearch.sql.planner.physical.AggregationOperator; -import org.opensearch.sql.planner.physical.FilterOperator; -import org.opensearch.sql.planner.physical.LimitOperator; -import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; -import org.opensearch.sql.planner.physical.ProjectOperator; import org.opensearch.sql.storage.Table; @ExtendWith(MockitoExtension.class) @@ -144,25 +128,28 @@ void getFieldTypes() { .put("blob", "binary") .build()))); - Map fieldTypes = index.getFieldTypes(); - assertThat( - fieldTypes, - allOf( - aMapWithSize(13), - hasEntry("name", 
ExprCoreType.STRING), - hasEntry("address", (ExprType) OpenSearchDataType.OPENSEARCH_TEXT), - hasEntry("age", ExprCoreType.INTEGER), - hasEntry("account_number", ExprCoreType.LONG), - hasEntry("balance1", ExprCoreType.FLOAT), - hasEntry("balance2", ExprCoreType.DOUBLE), - hasEntry("gender", ExprCoreType.BOOLEAN), - hasEntry("family", ExprCoreType.ARRAY), - hasEntry("employer", ExprCoreType.STRUCT), - hasEntry("birthday", ExprCoreType.TIMESTAMP), - hasEntry("id1", ExprCoreType.BYTE), - hasEntry("id2", ExprCoreType.SHORT), - hasEntry("blob", (ExprType) OpenSearchDataType.OPENSEARCH_BINARY) - )); + // Run more than once to confirm caching logic is covered and can work + for (int i = 0; i < 2; i++) { + Map fieldTypes = index.getFieldTypes(); + assertThat( + fieldTypes, + allOf( + aMapWithSize(13), + hasEntry("name", ExprCoreType.STRING), + hasEntry("address", (ExprType) OpenSearchDataType.OPENSEARCH_TEXT), + hasEntry("age", ExprCoreType.INTEGER), + hasEntry("account_number", ExprCoreType.LONG), + hasEntry("balance1", ExprCoreType.FLOAT), + hasEntry("balance2", ExprCoreType.DOUBLE), + hasEntry("gender", ExprCoreType.BOOLEAN), + hasEntry("family", ExprCoreType.ARRAY), + hasEntry("employer", ExprCoreType.STRUCT), + hasEntry("birthday", ExprCoreType.TIMESTAMP), + hasEntry("id1", ExprCoreType.BYTE), + hasEntry("id2", ExprCoreType.SHORT), + hasEntry("blob", (ExprType) OpenSearchDataType.OPENSEARCH_BINARY) + )); + } } @Test @@ -170,7 +157,7 @@ void implementRelationOperatorOnly() { when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - LogicalPlan plan = relation(indexName, table); + LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); assertEquals( new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), @@ -182,7 +169,7 @@ void implementRelationOperatorWithOptimization() { 
when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - LogicalPlan plan = relation(indexName, table); + LogicalPlan plan = index.createScanBuilder(); Integer maxResultWindow = index.getMaxResultWindow(); assertEquals( new OpenSearchIndexScan(client, settings, indexName, maxResultWindow, exprValueFactory), @@ -217,7 +204,7 @@ void implementOtherLogicalOperators() { eval( remove( rename( - relation(indexName, table), + index.createScanBuilder(), mappings), exclude), newEvalField), @@ -243,214 +230,4 @@ void implementOtherLogicalOperators() { include), index.implement(plan)); } - - @Test - void shouldImplLogicalIndexScan() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - NamedExpression named = named("n", field); - Expression filterExpr = DSL.equal(field, literal("John")); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - filterExpr - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldNotPushDownFilterFarFromRelation() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - Expression filterExpr = DSL.equal(field, literal("John")); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); - - PhysicalPlan plan = index.implement( - filter( - aggregation( - relation(indexName, table), - aggregators, - groupByExprs - ), - filterExpr)); - - 
assertTrue(plan instanceof FilterOperator); - } - - @Test - void shouldImplLogicalIndexScanAgg() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - Expression filterExpr = DSL.equal(field, literal("John")); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); - - // IndexScanAgg without Filter - PhysicalPlan plan = index.implement( - filter( - indexScanAgg( - indexName, - aggregators, - groupByExprs - ), - filterExpr)); - - assertTrue(plan.getChild().get(0) instanceof OpenSearchIndexScan); - - // IndexScanAgg with Filter - plan = index.implement( - indexScanAgg( - indexName, - filterExpr, - aggregators, - groupByExprs)); - assertTrue(plan instanceof OpenSearchIndexScan); - } - - @Test - void shouldNotPushDownAggregationFarFromRelation() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - Expression filterExpr = DSL.equal(field, literal("John")); - List groupByExprs = Arrays.asList(named("age", ref("age", INTEGER))); - List aggregators = - Arrays.asList(named("avg(age)", new AvgAggregator(Arrays.asList(ref("age", INTEGER)), - DOUBLE))); - - PhysicalPlan plan = index.implement( - aggregation( - filter(filter( - relation(indexName, table), - filterExpr), filterExpr), - aggregators, - groupByExprs)); - assertTrue(plan instanceof AggregationOperator); - } - - @Test - void shouldImplIndexScanWithSort() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = 
ref("name", STRING); - NamedExpression named = named("n", field); - Expression sortExpr = ref("name", STRING); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - Pair.of(Sort.SortOption.DEFAULT_ASC, sortExpr) - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldImplIndexScanWithLimit() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - NamedExpression named = named("n", field); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - 1, 1, noProjects() - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldImplIndexScanWithSortAndLimit() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - ReferenceExpression field = ref("name", STRING); - NamedExpression named = named("n", field); - Expression sortExpr = ref("name", STRING); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, - sortExpr, - 1, 1, - noProjects() - ), - named)); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - } - - @Test - void shouldNotPushDownLimitFarFromRelationButUpdateScanSize() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - PhysicalPlan plan = index.implement(index.optimize( - project( - limit( - sort( - relation("test", table), - Pair.of(Sort.SortOption.DEFAULT_ASC, - DSL.abs(named("intV", 
ref("intV", INTEGER)))) - ), - 300, 1 - ), - named("intV", ref("intV", INTEGER)) - ) - )); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof LimitOperator); - } - - @Test - void shouldPushDownProjects() { - when(settings.getSettingValue(Settings.Key.QUERY_SIZE_LIMIT)).thenReturn(200); - when(client.getIndexMaxResultWindows("test")).thenReturn(Map.of("test", 10000)); - - PhysicalPlan plan = index.implement( - project( - indexScan( - indexName, projects(ref("intV", INTEGER)) - ), - named("i", ref("intV", INTEGER)))); - - assertTrue(plan instanceof ProjectOperator); - assertTrue(((ProjectOperator) plan).getInput() instanceof OpenSearchIndexScan); - - final FetchSourceContext fetchSource = - ((OpenSearchIndexScan) ((ProjectOperator) plan).getInput()).getRequestBuilder() - .getSourceBuilder().fetchSource(); - assertThat(fetchSource.includes(), arrayContaining("intV")); - assertThat(fetchSource.excludes(), emptyArray()); - } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java new file mode 100644 index 0000000000..363727cbd3 --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/OpenSearchIndexScanOptimizationTest.java @@ -0,0 +1,609 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.storage.scan; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_FIRST; +import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; +import static 
org.opensearch.sql.data.model.ExprValueUtils.integerValue; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; +import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.data.type.ExprCoreType.LONG; +import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.limit; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation; +import static org.opensearch.sql.planner.logical.LogicalPlanDSL.sort; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_AGGREGATION; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_FILTER; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_HIGHLIGHT; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_LIMIT; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_PROJECT; +import static org.opensearch.sql.planner.optimizer.rule.read.TableScanPushDown.PUSH_DOWN_SORT; + +import com.google.common.collect.ImmutableList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import lombok.Builder; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.index.query.QueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import 
org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.aggregations.bucket.composite.CompositeAggregationBuilder; +import org.opensearch.search.aggregations.bucket.composite.TermsValuesSourceBuilder; +import org.opensearch.search.sort.SortBuilder; +import org.opensearch.search.sort.SortBuilders; +import org.opensearch.search.sort.SortOrder; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.HighlightExpression; +import org.opensearch.sql.expression.ReferenceExpression; +import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; +import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SingleValueParser; +import org.opensearch.sql.opensearch.storage.OpenSearchIndexScan; +import org.opensearch.sql.opensearch.storage.script.aggregation.AggregationQueryBuilder; +import org.opensearch.sql.planner.logical.LogicalPlan; +import org.opensearch.sql.planner.optimizer.LogicalPlanOptimizer; +import org.opensearch.sql.planner.optimizer.rule.read.CreateTableScanBuilder; +import org.opensearch.sql.storage.Table; + + +@ExtendWith(MockitoExtension.class) +class OpenSearchIndexScanOptimizationTest { + + @Mock + private Table table; + + @Mock + private OpenSearchIndexScan indexScan; + + private OpenSearchIndexScanBuilder indexScanBuilder; + + @Mock + private OpenSearchRequestBuilder requestBuilder; + + private Runnable[] verifyPushDownCalls = {}; + + @BeforeEach + void setUp() { + indexScanBuilder = new OpenSearchIndexScanBuilder(indexScan); + when(table.createScanBuilder()).thenReturn(indexScanBuilder); + 
when(indexScan.getRequestBuilder()).thenReturn(requestBuilder); + } + + @Test + void test_project_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withProjectPushedDown(DSL.ref("intV", INTEGER))), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER))) + ); + } + + /** + * SELECT intV as i FROM schema WHERE intV = 1. + */ + @Test + void test_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + //withProjectPushedDown(DSL.ref("intV", INTEGER)), + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + project( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ) + ); + } + + /** + * SELECT avg(intV) FROM schema GROUP BY string_value. + */ + @Test + void test_aggregation_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("longV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + /* + @Disabled("This test should be enabled once https://github.com/opensearch-project/sql/issues/912 is fixed") + @Test + void aggregation_cant_merge_indexScan_with_project() { + assertEquals( + aggregation( + OpenSearchLogicalIndexScan.builder().relationName("schema") + .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) + .projectList(ImmutableSet.of(DSL.ref("intV", INTEGER))) + .build(), + ImmutableList + 
.of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG))))), + optimize( + aggregation( + OpenSearchLogicalIndexScan.builder().relationName("schema") + .filter(DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1)))) + .projectList( + ImmutableSet.of(DSL.ref("intV", INTEGER))) + .build(), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG)))))) + ); + } + */ + + /** + * Sort - Relation --> IndexScan. + */ + @Test + void test_sort_push_down() { + assertEqualsAfterOptimization( + indexScanBuilder( + withSortPushedDown( + SortBuilders.fieldSort("intV").order(SortOrder.ASC).missing("_first")) + ), + sort( + relation("schema", table), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void test_limit_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withLimitPushedDown(1, 1)), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + project( + limit( + relation("schema", table), + 1, 1), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void test_highlight_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withHighlightPushedDown("*", Collections.emptyMap())), + DSL.named("highlight(*)", + new HighlightExpression(DSL.literal("*"))) + ), + project( + highlight( + relation("schema", table), + DSL.literal("*"), Collections.emptyMap()), + DSL.named("highlight(*)", + new HighlightExpression(DSL.literal("*"))) + ) + ); + } + + /** + * SELECT avg(intV) FROM schema WHERE intV = 1 GROUP BY string_value. 
+ */ + @Test + void test_aggregation_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("longV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ), + project( + aggregation( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + /** + * Sort - Filter - Relation --> IndexScan. + */ + @Test + void test_sort_filter_push_down() { + assertEqualsAfterOptimization( + indexScanBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withSortPushedDown( + SortBuilders.fieldSort("longV").order(SortOrder.ASC).missing("_first")) + ), + sort( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) + ) + ); + } + + /** + * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY stringV. 
+ */ + @Test + void test_sort_aggregation_push_down() { + assertEqualsAfterOptimization( + project( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("stringV") + .sortBy(SortOption.DEFAULT_DESC) + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "stringV", STRING)))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + sort( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), + Pair.of(SortOption.DEFAULT_DESC, DSL.ref("stringV", STRING)) + ), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + @Test + void test_limit_sort_filter_push_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder( + withFilterPushedDown(QueryBuilders.termQuery("intV", 1)), + withSortPushedDown( + SortBuilders.fieldSort("longV").order(SortOrder.ASC).missing("_first")), + withLimitPushedDown(1, 1)), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ), + project( + limit( + sort( + filter( + relation("schema", table), + DSL.equal(DSL.ref("intV", INTEGER), DSL.literal(integerValue(1))) + ), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("longV", LONG)) + ), 1, 1 + ), + DSL.named("intV", DSL.ref("intV", INTEGER)) + ) + ); + } + + /* + * Project(ListA) -> Project(ListB) -> Relation. + * -- will be optimized as + * Project(ListA) -> Project(ListB) -> Relation(project=ListB). 
+ */ + @Test + void only_one_project_should_be_push() { + assertEqualsAfterOptimization( + project( + project( + indexScanBuilder( + withProjectPushedDown( + DSL.ref("intV", INTEGER), + DSL.ref("stringV", STRING))), + DSL.named("i", DSL.ref("intV", INTEGER)), + DSL.named("s", DSL.ref("stringV", STRING)) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ), + project( + project( + relation("schema", table), + DSL.named("i", DSL.ref("intV", INTEGER)), + DSL.named("s", DSL.ref("stringV", STRING)) + ), + DSL.named("i", DSL.ref("intV", INTEGER)) + ) + ); + } + + @Test + void sort_with_expression_cannot_merge_with_relation() { + assertEqualsAfterOptimization( + sort( + indexScanBuilder(), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ), + sort( + relation("schema", table), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void sort_with_expression_cannot_merge_with_aggregation() { + assertEqualsAfterOptimization( + sort( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("stringV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ), + sort( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.abs(DSL.ref("intV", INTEGER))) + ) + ); + } + + @Test + void aggregation_cant_merge_index_scan_with_limit() { + assertEqualsAfterOptimization( + project( + aggregation( + indexScanBuilder( + withLimitPushedDown(10, 0)), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG))))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + aggregation( + limit( + relation("schema", 
table), + 10, 0), + ImmutableList + .of(DSL.named("AVG(intV)", + DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("longV", + DSL.abs(DSL.ref("longV", LONG))))), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)))); + } + + /** + * Can't Optimize the following query. + * SELECT avg(intV) FROM schema GROUP BY stringV ORDER BY avg(intV). + */ + @Test + void sort_refer_to_aggregator_should_not_merge_with_indexAgg() { + assertEqualsAfterOptimization( + project( + sort( + indexScanAggBuilder( + withAggregationPushedDown( + aggregate("AVG(intV)") + .aggregateBy("intV") + .groupBy("stringV") + .resultTypes(Map.of( + "AVG(intV)", DOUBLE, + "stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("AVG(intV)", INTEGER)) + ), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE))), + project( + sort( + aggregation( + relation("schema", table), + ImmutableList + .of(DSL.named("AVG(intV)", DSL.avg(DSL.ref("intV", INTEGER)))), + ImmutableList.of(DSL.named("stringV", DSL.ref("stringV", STRING)))), + Pair.of(SortOption.DEFAULT_ASC, DSL.ref("AVG(intV)", INTEGER)) + ), + DSL.named("AVG(intV)", DSL.ref("AVG(intV)", DOUBLE)) + ) + ); + } + + @Test + void project_literal_should_not_be_pushed_down() { + assertEqualsAfterOptimization( + project( + indexScanBuilder(), + DSL.named("i", DSL.literal("str")) + ), + optimize( + project( + relation("schema", table), + DSL.named("i", DSL.literal("str")) + ) + ) + ); + } + + private OpenSearchIndexScanBuilder indexScanBuilder(Runnable... verifyPushDownCalls) { + this.verifyPushDownCalls = verifyPushDownCalls; + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanQueryBuilder(indexScan)); + } + + private OpenSearchIndexScanBuilder indexScanAggBuilder(Runnable... 
verifyPushDownCalls) { + this.verifyPushDownCalls = verifyPushDownCalls; + return new OpenSearchIndexScanBuilder(new OpenSearchIndexScanAggregationBuilder(indexScan)); + } + + private void assertEqualsAfterOptimization(LogicalPlan expected, LogicalPlan actual) { + assertEquals(expected, optimize(actual)); + + // Trigger build to make sure all push down actually happened in scan builder + indexScanBuilder.build(); + + // Verify to make sure all push down methods are called as expected + if (verifyPushDownCalls.length == 0) { + reset(indexScan); + } else { + Arrays.stream(verifyPushDownCalls).forEach(Runnable::run); + } + } + + private Runnable withFilterPushedDown(QueryBuilder filteringCondition) { + return () -> verify(requestBuilder, times(1)).pushDown(filteringCondition); + } + + private Runnable withAggregationPushedDown( + AggregationAssertHelper.AggregationAssertHelperBuilder aggregation) { + + // Assume single term bucket and AVG metric in all tests in this suite + CompositeAggregationBuilder aggBuilder = AggregationBuilders.composite( + "composite_buckets", + Collections.singletonList( + new TermsValuesSourceBuilder(aggregation.groupBy) + .field(aggregation.groupBy) + .order(aggregation.sortBy.getSortOrder() == ASC ? "asc" : "desc") + .missingOrder(aggregation.sortBy.getNullOrder() == NULL_FIRST ? "first" : "last") + .missingBucket(true))) + .subAggregation( + AggregationBuilders.avg(aggregation.aggregateName) + .field(aggregation.aggregateBy)) + .size(AggregationQueryBuilder.AGGREGATION_BUCKET_SIZE); + + List aggBuilders = Collections.singletonList(aggBuilder); + OpenSearchAggregationResponseParser responseParser = + new CompositeAggregationParser( + new SingleValueParser(aggregation.aggregateName)); + + return () -> { + verify(requestBuilder, times(1)).pushDownAggregation(Pair.of(aggBuilders, responseParser)); + verify(requestBuilder, times(1)).pushTypeMapping(aggregation.resultTypes); + }; + } + + private Runnable withSortPushedDown(SortBuilder... 
sorts) { + return () -> verify(requestBuilder, times(1)).pushDownSort(Arrays.asList(sorts)); + } + + private Runnable withLimitPushedDown(int size, int offset) { + return () -> verify(requestBuilder, times(1)).pushDownLimit(size, offset); + } + + private Runnable withProjectPushedDown(ReferenceExpression... references) { + return () -> verify(requestBuilder, times(1)).pushDownProjects( + new HashSet<>(Arrays.asList(references))); + } + + private Runnable withHighlightPushedDown(String field, Map arguments) { + return () -> verify(requestBuilder, times(1)).pushDownHighlight(field, arguments); + } + + private static AggregationAssertHelper.AggregationAssertHelperBuilder aggregate(String aggName) { + var aggBuilder = new AggregationAssertHelper.AggregationAssertHelperBuilder(); + aggBuilder.aggregateName = aggName; + aggBuilder.sortBy = SortOption.DEFAULT_ASC; + return aggBuilder; + } + + /** Assertion helper for readability. */ + @Builder + private static class AggregationAssertHelper { + + String aggregateName; + + String aggregateBy; + + String groupBy; + + SortOption sortBy; + + Map resultTypes; + } + + private LogicalPlan optimize(LogicalPlan plan) { + LogicalPlanOptimizer optimizer = new LogicalPlanOptimizer(List.of( + new CreateTableScanBuilder(), + PUSH_DOWN_FILTER, + PUSH_DOWN_AGGREGATION, + PUSH_DOWN_SORT, + PUSH_DOWN_LIMIT, + PUSH_DOWN_HIGHLIGHT, + PUSH_DOWN_PROJECT)); + return optimizer.optimize(plan); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java index 2ed9a16434..85b8889de3 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/utils/Utils.java @@ -20,141 +20,10 @@ import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.AvgAggregator; import 
org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexAgg; -import org.opensearch.sql.opensearch.planner.logical.OpenSearchLogicalIndexScan; -import org.opensearch.sql.planner.logical.LogicalPlan; @UtilityClass public class Utils { - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, Expression filter) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Pair... sorts) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .sortList(Arrays.asList(sorts)) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Expression filter, - Pair... sorts) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .sortList(Arrays.asList(sorts)) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, Integer offset, Integer limit, - Set projectList) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .offset(offset) - .limit(limit) - .projectList(projectList) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Expression filter, - Integer offset, Integer limit, - Set projectList) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .offset(offset) - .limit(limit) - .projectList(projectList) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. 
- */ - public static LogicalPlan indexScan(String tableName, - Expression filter, - Integer offset, Integer limit, - List> sorts, - Set projectList) { - return OpenSearchLogicalIndexScan.builder().relationName(tableName) - .filter(filter) - .sortList(sorts) - .offset(offset) - .limit(limit) - .projectList(projectList) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, - Set projects) { - return OpenSearchLogicalIndexScan.builder() - .relationName(tableName) - .projectList(projects) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexScan. - */ - public static LogicalPlan indexScan(String tableName, Expression filter, - Set projects) { - return OpenSearchLogicalIndexScan.builder() - .relationName(tableName) - .filter(filter) - .projectList(projects) - .build(); - } - - /** - * Build ElasticsearchLogicalIndexAgg. - */ - public static LogicalPlan indexScanAgg(String tableName, List aggregators, - List groupByList) { - return OpenSearchLogicalIndexAgg.builder().relationName(tableName) - .aggregatorList(aggregators).groupByList(groupByList).build(); - } - - /** - * Build ElasticsearchLogicalIndexAgg. - */ - public static LogicalPlan indexScanAgg(String tableName, List aggregators, - List groupByList, - List> sortList) { - return OpenSearchLogicalIndexAgg.builder().relationName(tableName) - .aggregatorList(aggregators).groupByList(groupByList).sortList(sortList).build(); - } - - /** - * Build ElasticsearchLogicalIndexAgg. 
- */ - public static LogicalPlan indexScanAgg(String tableName, - Expression filter, - List aggregators, - List groupByList) { - return OpenSearchLogicalIndexAgg.builder().relationName(tableName).filter(filter) - .aggregatorList(aggregators).groupByList(groupByList).build(); - } - public static AvgAggregator avg(Expression expr, ExprCoreType type) { return new AvgAggregator(Arrays.asList(expr), type); } From 2af7321065d2b21845c04146de13e46f97566514 Mon Sep 17 00:00:00 2001 From: Guian Gumpac Date: Wed, 7 Dec 2022 15:59:43 -0800 Subject: [PATCH 3/5] Add support for wildcard_query function to the new engine (#156) (#1108) Signed-off-by: Guian Gumpac --- .../org/opensearch/sql/expression/DSL.java | 4 + .../function/BuiltinFunctionName.java | 4 +- .../function/OpenSearchFunctions.java | 7 + .../sql/analysis/ExpressionAnalyzerTest.java | 28 +++ .../function/OpenSearchFunctionsTest.java | 8 + docs/user/dql/functions.rst | 55 ++++++ doctest/test_data/wildcard.json | 22 +++ doctest/test_docs.py | 4 +- doctest/test_mapping/wildcard.json | 9 + .../sql/legacy/SQLIntegTestCase.java | 6 +- .../opensearch/sql/legacy/TestsConstants.java | 1 + .../org/opensearch/sql/ppl/LikeQueryIT.java | 88 +++++++++ .../org/opensearch/sql/sql/LikeQueryIT.java | 140 ++++++++++++++ .../opensearch/sql/sql/WildcardQueryIT.java | 183 ++++++++++++++++++ .../wildcard_index_mappings.json | 21 ++ integ-test/src/test/resources/wildcard.json | 20 ++ .../storage/script/StringUtils.java | 54 ++++++ .../script/filter/FilterQueryBuilder.java | 7 +- .../script/filter/lucene/LikeQuery.java | 37 ++++ .../script/filter/lucene/WildcardQuery.java | 31 --- .../FunctionParameterRepository.java | 9 + .../lucene/relevance/WildcardQuery.java | 35 ++++ .../storage/script/StringUtilsTest.java | 29 +++ .../script/filter/FilterQueryBuilderTest.java | 124 ++++++++++++ .../filter/lucene/WildcardQueryTest.java | 94 +++++++++ sql/src/main/antlr/OpenSearchSQLLexer.g4 | 1 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 3 +- 
.../sql/sql/antlr/SQLSyntaxParserTest.java | 16 +- .../sql/parser/AstExpressionBuilderTest.java | 13 ++ 29 files changed, 1015 insertions(+), 38 deletions(-) create mode 100644 doctest/test_data/wildcard.json create mode 100644 doctest/test_mapping/wildcard.json create mode 100644 integ-test/src/test/java/org/opensearch/sql/ppl/LikeQueryIT.java create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/LikeQueryIT.java create mode 100644 integ-test/src/test/java/org/opensearch/sql/sql/WildcardQueryIT.java create mode 100644 integ-test/src/test/resources/indexDefinitions/wildcard_index_mappings.json create mode 100644 integ-test/src/test/resources/wildcard.json create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/StringUtils.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/LikeQuery.java delete mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQuery.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/WildcardQuery.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/StringUtilsTest.java create mode 100644 opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index f5fd1e3315..3b601f98a3 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -715,6 +715,10 @@ public static FunctionExpression match_bool_prefix(Expression... args) { return compile(FunctionProperties.None, BuiltinFunctionName.MATCH_BOOL_PREFIX, args); } + public static FunctionExpression wildcard_query(Expression... 
args) { + return compile(FunctionProperties.None,BuiltinFunctionName.WILDCARD_QUERY, args); + } + public static FunctionExpression now(FunctionProperties functionProperties, Expression... args) { return compile(functionProperties, BuiltinFunctionName.NOW, args); diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index b09f3b0c74..0b7701d8a9 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -228,7 +228,9 @@ public enum BuiltinFunctionName { MATCHQUERY(FunctionName.of("matchquery")), MULTI_MATCH(FunctionName.of("multi_match")), MULTIMATCH(FunctionName.of("multimatch")), - MULTIMATCHQUERY(FunctionName.of("multimatchquery")); + MULTIMATCHQUERY(FunctionName.of("multimatchquery")), + WILDCARDQUERY(FunctionName.of("wildcardquery")), + WILDCARD_QUERY(FunctionName.of("wildcard_query")); private final FunctionName name; diff --git a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java index 2041b9762e..d8efe42640 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/OpenSearchFunctions.java @@ -42,6 +42,8 @@ public void register(BuiltinFunctionRepository repository) { repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASE)); repository.register(match_phrase(BuiltinFunctionName.MATCHPHRASEQUERY)); repository.register(match_phrase_prefix()); + repository.register(wildcard_query(BuiltinFunctionName.WILDCARD_QUERY)); + repository.register(wildcard_query(BuiltinFunctionName.WILDCARDQUERY)); } private static FunctionResolver match_bool_prefix() { @@ -83,6 +85,11 @@ private static 
FunctionResolver query_string() { return new RelevanceFunctionResolver(funcName, STRUCT); } + private static FunctionResolver wildcard_query(BuiltinFunctionName wildcardQuery) { + FunctionName funcName = wildcardQuery.getName(); + return new RelevanceFunctionResolver(funcName, STRING); + } + public static class OpenSearchFunction extends FunctionExpression { private final FunctionName functionName; private final List arguments; diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index dfb7a7239f..7114b220ab 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -540,6 +540,34 @@ void query_string_expression_two_fields() { AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } + @Test + void wildcard_query_expression() { + assertAnalyzeEqual( + DSL.wildcard_query( + DSL.namedArgument("field", DSL.literal("test")), + DSL.namedArgument("query", DSL.literal("query_value*"))), + AstDSL.function("wildcard_query", + unresolvedArg("field", stringLiteral("test")), + unresolvedArg("query", stringLiteral("query_value*")))); + } + + @Test + void wildcard_query_expression_all_params() { + assertAnalyzeEqual( + DSL.wildcard_query( + DSL.namedArgument("field", DSL.literal("test")), + DSL.namedArgument("query", DSL.literal("query_value*")), + DSL.namedArgument("boost", DSL.literal("1.5")), + DSL.namedArgument("case_insensitive", DSL.literal("true")), + DSL.namedArgument("rewrite", DSL.literal("scoring_boolean"))), + AstDSL.function("wildcard_query", + unresolvedArg("field", stringLiteral("test")), + unresolvedArg("query", stringLiteral("query_value*")), + unresolvedArg("boost", stringLiteral("1.5")), + unresolvedArg("case_insensitive", stringLiteral("true")), + unresolvedArg("rewrite", stringLiteral("scoring_boolean")))); + } + @Test public void 
match_phrase_prefix_all_params() { assertAnalyzeEqual( diff --git a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java b/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java index 787ca016c9..6e4fff2fb0 100644 --- a/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/function/OpenSearchFunctionsTest.java @@ -197,4 +197,12 @@ void query_string() { fields.getValue(), query.getValue()), expr.toString()); } + + @Test + void wildcard_query() { + FunctionExpression expr = DSL.wildcard_query(field, query); + assertEquals(String.format("wildcard_query(field=%s, query=%s)", + field.getValue(), query.getValue()), + expr.toString()); + } } diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index c4c2fa988b..7be50ccffb 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -3299,6 +3299,59 @@ Example searching for field Tags:: | [Winnie-the-Pooh] | +----------------------------------------------+ +WILDCARD_QUERY +-------------- + +Description +>>>>>>>>>>> + +``wildcard_query(field_expression, query_expression[, option=]*)`` + +The ``wildcard_query`` function maps to the ``wildcard_query`` query used in search engine. It returns documents that match provided text in the specified field. +OpenSearch supports wildcard characters ``*`` and ``?``. See the full description here: https://opensearch.org/docs/latest/opensearch/query-dsl/term/#wildcards. +You may include a backslash ``\`` to escape SQL wildcard characters ``\%`` and ``\_``. + +Available parameters include: + +- boost +- case_insensitive +- rewrite + +For backward compatibility, ``wildcardquery`` is also supported and mapped to ``wildcard_query`` query as well. 
+ +Example with only ``field`` and ``query`` expressions, with all other parameters set to their default values:: + + os> select Body from wildcard where wildcard_query(Body, 'test wildcard*'); + fetched rows / total rows = 7/7 + +-------------------------------------------+ + | Body | + |-------------------------------------------| + | test wildcard | + | test wildcard in the end of the text% | + | test wildcard in % the middle of the text | + | test wildcard %% beside each other | + | test wildcard in the end of the text_ | + | test wildcard in _ the middle of the text | + | test wildcard __ beside each other | + +-------------------------------------------+ + +Another example to show how to set custom values for the optional parameters:: + + os> select Body from wildcard where wildcard_query(Body, 'test wildcard*', boost=0.7, case_insensitive=true, rewrite='constant_score'); + fetched rows / total rows = 8/8 + +-------------------------------------------+ + | Body | + |-------------------------------------------| + | test wildcard | + | test wildcard in the end of the text% | + | test wildcard in % the middle of the text | + | test wildcard %% beside each other | + | test wildcard in the end of the text_ | + | test wildcard in _ the middle of the text | + | test wildcard __ beside each other | + | tEsT wIlDcArD sensitive cases | + +-------------------------------------------+ + System Functions ================ @@ -3323,3 +3376,5 @@ Example:: |----------------+---------------+-----------------+------------------| | DATE | INTEGER | DATETIME | STRUCT | +----------------+---------------+-----------------+------------------+ + + diff --git a/doctest/test_data/wildcard.json b/doctest/test_data/wildcard.json new file mode 100644 index 0000000000..c91778d8ab --- /dev/null +++ b/doctest/test_data/wildcard.json @@ -0,0 +1,22 @@ +{"index":{"_id":"0"}} +{"Body":"test wildcard"} +{"index":{"_id":"1"}} +{"Body":"test wildcard in the end of the text%"} +{"index":{"_id":"2"}} 
+{"Body":"%test wildcard in the beginning of the text"} +{"index":{"_id":"3"}} +{"Body":"test wildcard in % the middle of the text"} +{"index":{"_id":"4"}} +{"Body":"test wildcard %% beside each other"} +{"index":{"_id":"5"}} +{"Body":"test wildcard in the end of the text_"} +{"index":{"_id":"6"}} +{"Body":"_test wildcard in the beginning of the text"} +{"index":{"_id":"7"}} +{"Body":"test wildcard in _ the middle of the text"} +{"index":{"_id":"8"}} +{"Body":"test wildcard __ beside each other"} +{"index":{"_id":"9"}} +{"Body":"test backslash wildcard \\_"} +{"index":{"_id":"10"}} +{"Body":"tEsT wIlDcArD sensitive cases"} diff --git a/doctest/test_docs.py b/doctest/test_docs.py index 6d2538196a..b5edf46de9 100644 --- a/doctest/test_docs.py +++ b/doctest/test_docs.py @@ -26,6 +26,7 @@ NYC_TAXI = "nyc_taxi" BOOKS = "books" APACHE = "apache" +WILDCARD = "wildcard" class DocTestConnection(OpenSearchConnection): @@ -92,6 +93,7 @@ def set_up_test_indices(test): load_file("nyc_taxi.json", index_name=NYC_TAXI) load_file("books.json", index_name=BOOKS) load_file("apache.json", index_name=APACHE) + load_file("wildcard.json", index_name=WILDCARD) def load_file(filename, index_name): @@ -120,7 +122,7 @@ def set_up(test): def tear_down(test): # drop leftover tables after each test - test_data_client.indices.delete(index=[ACCOUNTS, EMPLOYEES, PEOPLE, ACCOUNT2, NYC_TAXI, BOOKS, APACHE], ignore_unavailable=True) + test_data_client.indices.delete(index=[ACCOUNTS, EMPLOYEES, PEOPLE, ACCOUNT2, NYC_TAXI, BOOKS, APACHE, WILDCARD], ignore_unavailable=True) docsuite = partial(doctest.DocFileSuite, diff --git a/doctest/test_mapping/wildcard.json b/doctest/test_mapping/wildcard.json new file mode 100644 index 0000000000..670a774ae1 --- /dev/null +++ b/doctest/test_mapping/wildcard.json @@ -0,0 +1,9 @@ +{ + "mappings" : { + "properties" : { + "Body" : { + "type" : "keyword" + } + } + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java 
b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java index f03acbbbfd..80348b2a8b 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/SQLIntegTestCase.java @@ -584,7 +584,11 @@ public enum Index { CALCS(TestsConstants.TEST_INDEX_CALCS, "calcs", getMappingFile("calcs_index_mappings.json"), - "src/test/resources/calcs.json"),; + "src/test/resources/calcs.json"), + WILDCARD(TestsConstants.TEST_INDEX_WILDCARD, + "wildcard", + getMappingFile("wildcard_index_mappings.json"), + "src/test/resources/wildcard.json"),; private final String name; private final String type; diff --git a/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java b/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java index a9f81c68fe..aff269fcce 100644 --- a/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java +++ b/integ-test/src/test/java/org/opensearch/sql/legacy/TestsConstants.java @@ -53,6 +53,7 @@ public class TestsConstants { public final static String TEST_INDEX_BEER = TEST_INDEX + "_beer"; public final static String TEST_INDEX_NULL_MISSING = TEST_INDEX + "_null_missing"; public final static String TEST_INDEX_CALCS = TEST_INDEX + "_calcs"; + public final static String TEST_INDEX_WILDCARD = TEST_INDEX + "_wildcard"; public final static String DATE_FORMAT = "yyyy-MM-dd'T'HH:mm:ss.SSS'Z'"; public final static String TS_DATE_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/LikeQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/LikeQueryIT.java new file mode 100644 index 0000000000..67ad553689 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/LikeQueryIT.java @@ -0,0 +1,88 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.ppl; + +import static 
org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WILDCARD; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.Test; + +public class LikeQueryIT extends PPLIntegTestCase { + + @Override + public void init() throws IOException { + loadIndex(Index.WILDCARD); + } + + @Test + public void test_like_with_percent() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(KeywordBody, 'test wildcard%') | fields KeywordBody"; + JSONObject result = executeQuery(query); + verifyDataRows(result, + rows("test wildcard"), + rows("test wildcard in the end of the text%"), + rows("test wildcard in % the middle of the text"), + rows("test wildcard %% beside each other"), + rows("test wildcard in the end of the text_"), + rows("test wildcard in _ the middle of the text"), + rows("test wildcard __ beside each other")); + } + + @Test + public void test_like_with_escaped_percent() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(KeywordBody, '\\\\%test wildcard%') | fields KeywordBody"; + JSONObject result = executeQuery(query); + verifyDataRows(result, + rows("%test wildcard in the beginning of the text")); + } + + @Test + public void test_like_in_where_with_escaped_underscore() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(KeywordBody, '\\\\_test wildcard%') | fields KeywordBody"; + JSONObject result = executeQuery(query); + verifyDataRows(result, + rows("_test wildcard in the beginning of the text")); + } + + @Test + public void test_like_on_text_field_with_one_word() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(TextBody, 'test*') | fields TextBody"; + JSONObject result = executeQuery(query); + assertEquals(9, result.getInt("total")); + } + + @Test + public void 
test_like_on_text_keyword_field_with_one_word() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(TextKeywordBody, 'test*') | fields TextKeywordBody"; + JSONObject result = executeQuery(query); + assertEquals(8, result.getInt("total")); + } + + @Test + public void test_like_on_text_keyword_field_with_greater_than_one_word() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(TextKeywordBody, 'test wild*') | fields TextKeywordBody"; + JSONObject result = executeQuery(query); + assertEquals(7, result.getInt("total")); + } + + @Test + public void test_like_on_text_field_with_greater_than_one_word() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(TextBody, 'test wild*') | fields TextBody"; + JSONObject result = executeQuery(query); + assertEquals(0, result.getInt("total")); + } + + @Test + public void test_convert_field_text_to_keyword() throws IOException { + String query = "source=" + TEST_INDEX_WILDCARD + " | WHERE Like(TextKeywordBody, '*') | fields TextKeywordBody"; + String result = explainQueryToString(query); + assertTrue(result.contains("TextKeywordBody.keyword")); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/LikeQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/LikeQueryIT.java new file mode 100644 index 0000000000..f0e82adb6f --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/LikeQueryIT.java @@ -0,0 +1,140 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql; + +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +import java.io.IOException; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WILDCARD; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +public class LikeQueryIT 
extends SQLIntegTestCase { + @Override + protected void init() throws Exception { + loadIndex(Index.WILDCARD); + } + + @Test + public void test_like_in_select() throws IOException { + String query = "SELECT KeywordBody, KeywordBody LIKE 'test wildcard%' FROM " + TEST_INDEX_WILDCARD; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("test wildcard", true), + rows("test wildcard in the end of the text%", true), + rows("%test wildcard in the beginning of the text", false), + rows("test wildcard in % the middle of the text", true), + rows("test wildcard %% beside each other", true), + rows("test wildcard in the end of the text_", true), + rows("_test wildcard in the beginning of the text", false), + rows("test wildcard in _ the middle of the text", true), + rows("test wildcard __ beside each other", true), + rows("test backslash wildcard \\_", false)); + } + + @Test + public void test_like_in_select_with_escaped_percent() throws IOException { + String query = "SELECT KeywordBody, KeywordBody LIKE '\\\\%test wildcard%' FROM " + TEST_INDEX_WILDCARD; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("test wildcard", false), + rows("test wildcard in the end of the text%", false), + rows("%test wildcard in the beginning of the text", true), + rows("test wildcard in % the middle of the text", false), + rows("test wildcard %% beside each other", false), + rows("test wildcard in the end of the text_", false), + rows("_test wildcard in the beginning of the text", false), + rows("test wildcard in _ the middle of the text", false), + rows("test wildcard __ beside each other", false), + rows("test backslash wildcard \\_", false)); + } + + @Test + public void test_like_in_select_with_escaped_underscore() throws IOException { + String query = "SELECT KeywordBody, KeywordBody LIKE '\\\\_test wildcard%' FROM " + TEST_INDEX_WILDCARD; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("test 
wildcard", false), + rows("test wildcard in the end of the text%", false), + rows("%test wildcard in the beginning of the text", false), + rows("test wildcard in % the middle of the text", false), + rows("test wildcard %% beside each other", false), + rows("test wildcard in the end of the text_", false), + rows("_test wildcard in the beginning of the text", true), + rows("test wildcard in _ the middle of the text", false), + rows("test wildcard __ beside each other", false), + rows("test backslash wildcard \\_", false)); + } + + @Test + public void test_like_in_where() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE KeywordBody LIKE 'test wildcard%'"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("test wildcard"), + rows("test wildcard in the end of the text%"), + rows("test wildcard in % the middle of the text"), + rows("test wildcard %% beside each other"), + rows("test wildcard in the end of the text_"), + rows("test wildcard in _ the middle of the text"), + rows("test wildcard __ beside each other")); + } + + @Test + public void test_like_in_where_with_escaped_percent() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE KeywordBody LIKE '\\\\%test wildcard%'"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("%test wildcard in the beginning of the text")); + } + + @Test + public void test_like_in_where_with_escaped_underscore() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE KeywordBody LIKE '\\\\_test wildcard%'"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("_test wildcard in the beginning of the text")); + } + + @Test + public void test_like_on_text_field_with_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE TextBody LIKE 'test*'"; + JSONObject result = 
executeJdbcRequest(query); + assertEquals(9, result.getInt("total")); + } + + @Test + public void test_like_on_text_keyword_field_with_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE TextKeywordBody LIKE 'test*'"; + JSONObject result = executeJdbcRequest(query); + assertEquals(8, result.getInt("total")); + } + + @Test + public void test_like_on_text_keyword_field_with_greater_than_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE TextKeywordBody LIKE 'test wild*'"; + JSONObject result = executeJdbcRequest(query); + assertEquals(7, result.getInt("total")); + } + + @Test + public void test_like_on_text_field_with_greater_than_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE TextBody LIKE 'test wild*'"; + JSONObject result = executeJdbcRequest(query); + assertEquals(0, result.getInt("total")); + } + + @Test + public void test_convert_field_text_to_keyword() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE TextKeywordBody LIKE '*'"; + String result = explainQuery(query); + assertTrue(result.contains("TextKeywordBody.keyword")); + } +} diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/WildcardQueryIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/WildcardQueryIT.java new file mode 100644 index 0000000000..ee636ed5ce --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/WildcardQueryIT.java @@ -0,0 +1,183 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.sql; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_WILDCARD; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.Test; +import org.opensearch.sql.legacy.SQLIntegTestCase; + +import static org.opensearch.sql.util.MatcherUtils.rows; +import static 
org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +public class WildcardQueryIT extends SQLIntegTestCase { + @Override + protected void init() throws Exception { + loadIndex(Index.WILDCARD); + } + + @Test + public void test_wildcard_query_asterisk_function() throws IOException { + String expected = "test wildcard"; + + String query1 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, 't*') LIMIT 1"; + JSONObject result1 = executeJdbcRequest(query1); + verifyDataRows(result1, rows(expected)); + + String query2 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcardquery(KeywordBody, 't*') LIMIT 1"; + JSONObject result2 = executeJdbcRequest(query2); + verifyDataRows(result2, rows(expected)); + } + + @Test + public void test_wildcard_query_question_mark_function() throws IOException { + String expected = "test wildcard"; + + String query1 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, 'test wild??rd')"; + JSONObject result1 = executeJdbcRequest(query1); + verifyDataRows(result1, rows(expected)); + + String query2 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcardquery(KeywordBody, 'test wild??rd')"; + JSONObject result2 = executeJdbcRequest(query2); + verifyDataRows(result2, rows(expected)); + } + + // SQL uses % 
as a wildcard which is converted to * in WildcardQuery.java + @Test + public void test_wildcard_query_sql_wildcard_percent_conversion() throws IOException { + String query1 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, 'test%')"; + JSONObject result1 = executeJdbcRequest(query1); + assertEquals(8, result1.getInt("total")); + + String query2 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, 'test*')"; + JSONObject result2 = executeJdbcRequest(query2); + assertEquals(result1.getInt("total"), result2.getInt("total")); + } + + // SQL uses _ as a wildcard which is converted to ? in WildcardQuery.java + @Test + public void test_wildcard_query_sql_wildcard_underscore_conversion() throws IOException { + String query1 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, 'test wild_ard*')"; + JSONObject result1 = executeJdbcRequest(query1); + assertEquals(7, result1.getInt("total")); + + String query2 = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, 'test wild?ard*')"; + JSONObject result2 = executeJdbcRequest(query2); + assertEquals(result1.getInt("total"), result2.getInt("total")); + } + + @Test + public void test_escaping_wildcard_percent_in_the_beginning_of_text() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '\\\\%*')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("%test wildcard in the beginning of the text")); + } + + @Test + public void test_escaping_wildcard_percent_in_text() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\%%')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("test wildcard in % the middle of the text"), + rows("test wildcard %% beside each other"), + rows("test wildcard 
in the end of the text%"), + rows("%test wildcard in the beginning of the text")); + } + + @Test + public void test_escaping_wildcard_percent_in_the_end_of_text() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\%')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("test wildcard in the end of the text%")); + } + + @Test + public void test_double_escaped_wildcard_percent() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\%\\\\%*')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("test wildcard %% beside each other")); + } + + @Test + public void test_escaping_wildcard_underscore_in_the_beginning_of_text() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '\\\\_*')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("_test wildcard in the beginning of the text")); + } + + @Test + public void test_escaping_wildcard_underscore_in_text() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\_*')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("test wildcard in _ the middle of the text"), + rows("test wildcard __ beside each other"), + rows("test wildcard in the end of the text_"), + rows("_test wildcard in the beginning of the text"), + rows("test backslash wildcard \\_")); + } + + @Test + public void test_escaping_wildcard_underscore_in_the_end_of_text() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\_')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, + rows("test wildcard in the end of the text_"), + rows("test backslash 
wildcard \\_")); + } + + @Test + public void test_double_escaped_wildcard_underscore() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\_\\\\_*')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("test wildcard __ beside each other")); + } + + @Test + public void test_backslash_wildcard() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(KeywordBody, '*\\\\\\\\\\\\_')"; + JSONObject result = executeJdbcRequest(query); + verifyDataRows(result, rows("test backslash wildcard \\_")); + } + + @Test + public void all_params_test() throws IOException { + String query = "SELECT KeywordBody FROM " + TEST_INDEX_WILDCARD + + " WHERE wildcard_query(KeywordBody, 'test*', boost = 0.9," + + " case_insensitive=true, rewrite='constant_score')"; + JSONObject result = executeJdbcRequest(query); + assertEquals(8, result.getInt("total")); + } + + @Test + public void test_wildcard_query_on_text_field_with_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(TextBody, 'test*')"; + JSONObject result = executeJdbcRequest(query); + assertEquals(9, result.getInt("total")); + } + + @Test + public void test_wildcard_query_on_text_keyword_field_with_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(TextKeywordBody, 'test*')"; + JSONObject result = executeJdbcRequest(query); + assertEquals(9, result.getInt("total")); + } + + @Test + public void test_wildcard_query_on_text_field_with_greater_than_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(TextBody, 'test wild*')"; + JSONObject result = executeJdbcRequest(query); + assertEquals(0, result.getInt("total")); + } + + @Test + public void 
test_wildcard_query_on_text_keyword_field_with_greater_than_one_word() throws IOException { + String query = "SELECT * FROM " + TEST_INDEX_WILDCARD + " WHERE wildcard_query(TextKeywordBody, 'test wild*')"; + JSONObject result = executeJdbcRequest(query); + assertEquals(0, result.getInt("total")); + } +} diff --git a/integ-test/src/test/resources/indexDefinitions/wildcard_index_mappings.json b/integ-test/src/test/resources/indexDefinitions/wildcard_index_mappings.json new file mode 100644 index 0000000000..b9974e9548 --- /dev/null +++ b/integ-test/src/test/resources/indexDefinitions/wildcard_index_mappings.json @@ -0,0 +1,21 @@ +{ + "mappings" : { + "properties" : { + "KeywordBody" : { + "type" : "keyword" + }, + "TextKeywordBody" : { + "type" : "text", + "fields" : { + "keyword" : { + "type" : "keyword", + "ignore_above":256 + } + } + }, + "TextBody" : { + "type" : "text" + } + } + } +} diff --git a/integ-test/src/test/resources/wildcard.json b/integ-test/src/test/resources/wildcard.json new file mode 100644 index 0000000000..b25772a47e --- /dev/null +++ b/integ-test/src/test/resources/wildcard.json @@ -0,0 +1,20 @@ +{"index":{"_id":"0"}} +{"KeywordBody":"test wildcard", "TextKeywordBody":"test wildcard", "TextBody":"test wildcard"} +{"index":{"_id":"1"}} +{"KeywordBody":"test wildcard in the end of the text%", "TextKeywordBody":"test wildcard in the end of the text%", "TextBody":"test wildcard in the end of the text%"} +{"index":{"_id":"2"}} +{"KeywordBody":"%test wildcard in the beginning of the text", "TextKeywordBody":"%test wildcard in the beginning of the text", "TextBody":"%test wildcard in the beginning of the text"} +{"index":{"_id":"3"}} +{"KeywordBody":"test wildcard in % the middle of the text", "TextKeywordBody":"test wildcard in % the middle of the text", "TextBody":"test wildcard in % the middle of the text"} +{"index":{"_id":"4"}} +{"KeywordBody":"test wildcard %% beside each other", "TextKeywordBody":"test wildcard %% beside each other", 
"TextBody":"test wildcard %% beside each other"} +{"index":{"_id":"5"}} +{"KeywordBody":"test wildcard in the end of the text_", "TextKeywordBody":"test wildcard in the end of the text_", "TextBody":"test wildcard in the end of the text_"} +{"index":{"_id":"6"}} +{"KeywordBody":"_test wildcard in the beginning of the text", "TextKeywordBody":"_test wildcard in the beginning of the text", "TextBody":"_test wildcard in the beginning of the text"} +{"index":{"_id":"7"}} +{"KeywordBody":"test wildcard in _ the middle of the text", "TextKeywordBody":"test wildcard in _ the middle of the text", "TextBody":"test wildcard in _ the middle of the text"} +{"index":{"_id":"8"}} +{"KeywordBody":"test wildcard __ beside each other", "TextKeywordBody":"test wildcard __ beside each other", "TextBody":"test wildcard __ beside each other"} +{"index":{"_id":"9"}} +{"KeywordBody":"test backslash wildcard \\_", "TextKeywordBody":"test backslash wildcard \\_", "TextBody":"test backslash wildcard \\_"} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/StringUtils.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/StringUtils.java new file mode 100644 index 0000000000..7b68bd5c92 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/StringUtils.java @@ -0,0 +1,54 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.storage.script; + +import lombok.experimental.UtilityClass; + +@UtilityClass +public class StringUtils { + /** + * Converts sql wildcard character % and _ to * and ?. 
+ * @param text string to be converted + * @return converted string + */ + public static String convertSqlWildcardToLucene(String text) { + final char DEFAULT_ESCAPE = '\\'; + StringBuilder convertedString = new StringBuilder(text.length()); + boolean escaped = false; + + for (char currentChar : text.toCharArray()) { + switch (currentChar) { + case DEFAULT_ESCAPE: + escaped = true; + convertedString.append(currentChar); + break; + case '%': + if (escaped) { + convertedString.deleteCharAt(convertedString.length() - 1); + convertedString.append("%"); + } else { + convertedString.append("*"); + } + escaped = false; + break; + case '_': + if (escaped) { + convertedString.deleteCharAt(convertedString.length() - 1); + convertedString.append("_"); + } else { + convertedString.append('?'); + } + escaped = false; + break; + default: + convertedString.append(currentChar); + escaped = false; + } + } + return convertedString.toString(); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java index 2c55a28b88..5f36954d4a 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilder.java @@ -24,11 +24,11 @@ import org.opensearch.sql.expression.FunctionExpression; import org.opensearch.sql.expression.function.BuiltinFunctionName; import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.LikeQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.LuceneQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.RangeQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.RangeQuery.Comparison; import 
org.opensearch.sql.opensearch.storage.script.filter.lucene.TermQuery; -import org.opensearch.sql.opensearch.storage.script.filter.lucene.WildcardQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchBoolPrefixQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhrasePrefixQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.MatchPhraseQuery; @@ -37,6 +37,7 @@ import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.QueryQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.QueryStringQuery; import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.SimpleQueryStringQuery; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.WildcardQuery; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @RequiredArgsConstructor @@ -57,7 +58,7 @@ public class FilterQueryBuilder extends ExpressionNodeVisitor b.quoteFieldSuffix(v.stringValue())) .build(); + public static final Map> + WildcardQueryBuildActions = ImmutableMap.>builder() + .put("boost", (b, v) -> b.boost(convertFloatValue(v, "boost"))) + .put("case_insensitive", (b, v) -> b.caseInsensitive(convertBoolValue(v, "case_insensitive"))) + .put("rewrite", (b, v) -> b.rewrite(checkRewrite(v, "rewrite"))) + .build(); + public static final Map ArgumentLimitations = ImmutableMap.builder() .put("boost", "Accepts only floating point values greater than 0.") diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/WildcardQuery.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/WildcardQuery.java new file mode 100644 index 0000000000..9fd37e3de7 --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/relevance/WildcardQuery.java @@ -0,0 +1,35 @@ +/* + * Copyright 
OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + + +package org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance; + +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.WildcardQueryBuilder; +import org.opensearch.sql.opensearch.storage.script.StringUtils; + +/** + * Lucene query that builds wildcard query. + */ +public class WildcardQuery extends SingleFieldQuery { + /** + * Default constructor for WildcardQuery configures how RelevanceQuery.build() handles + * named arguments. + */ + public WildcardQuery() { + super(FunctionParameterRepository.WildcardQueryBuildActions); + } + + @Override + protected String getQueryName() { + return WildcardQueryBuilder.NAME; + } + + @Override + protected WildcardQueryBuilder createBuilder(String field, String query) { + String matchText = StringUtils.convertSqlWildcardToLucene(query); + return QueryBuilders.wildcardQuery(field, matchText); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/StringUtilsTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/StringUtilsTest.java new file mode 100644 index 0000000000..24c809ebab --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/StringUtilsTest.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script; + +import static org.junit.Assert.assertEquals; + +import org.junit.jupiter.api.Test; + +public class StringUtilsTest { + @Test + public void test_escaping_sql_wildcards() { + assertEquals("%", StringUtils.convertSqlWildcardToLucene("\\%")); + assertEquals("\\*", StringUtils.convertSqlWildcardToLucene("\\*")); + assertEquals("_", StringUtils.convertSqlWildcardToLucene("\\_")); + assertEquals("\\?", StringUtils.convertSqlWildcardToLucene("\\?")); + assertEquals("%*", 
StringUtils.convertSqlWildcardToLucene("\\%%")); + assertEquals("*%", StringUtils.convertSqlWildcardToLucene("%\\%")); + assertEquals("%*%", StringUtils.convertSqlWildcardToLucene("\\%%\\%")); + assertEquals("*%*", StringUtils.convertSqlWildcardToLucene("%\\%%")); + assertEquals("_?", StringUtils.convertSqlWildcardToLucene("\\__")); + assertEquals("?_", StringUtils.convertSqlWildcardToLucene("_\\_")); + assertEquals("_?_", StringUtils.convertSqlWildcardToLucene("\\__\\_")); + assertEquals("?_?", StringUtils.convertSqlWildcardToLucene("_\\__")); + assertEquals("%\\*_\\?", StringUtils.convertSqlWildcardToLucene("\\%\\*\\_\\?")); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java index 737e61f54b..cea4e2488a 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/FilterQueryBuilderTest.java @@ -631,6 +631,130 @@ void should_build_match_phrase_query_with_custom_parameters() { DSL.namedArgument("zero_terms_query", literal("ALL"))))); } + @Test + void wildcard_query_invalid_parameter() { + FunctionExpression expr = DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query*")), + DSL.namedArgument("invalid_parameter", literal("invalid_value"))); + assertThrows(SemanticCheckException.class, () -> buildQuery(expr), + "Parameter invalid_parameter is invalid for wildcard_query function."); + } + + @Test + void wildcard_query_convert_sql_wildcard_to_lucene() { + // Test conversion of % wildcard to * + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query*\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + 
buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query%"))))); + + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query?\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query_"))))); + } + + @Test + void wildcard_query_escape_wildcards_characters() { + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query%\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query\\%"))))); + + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query_\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query\\_"))))); + + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query\\\\*\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query\\*"))))); + + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query\\\\?\",\n" + + " \"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query\\?"))))); + } + + @Test + void should_build_wildcard_query_with_default_parameters() { + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query*\",\n" + + " 
\"boost\" : 1.0\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query*"))))); + } + + @Test + void should_build_wildcard_query_query_with_custom_parameters() { + assertJsonEquals("{\n" + + " \"wildcard\" : {\n" + + " \"field\" : {\n" + + " \"wildcard\" : \"search query*\",\n" + + " \"boost\" : 0.6,\n" + + " \"case_insensitive\" : true,\n" + + " \"rewrite\" : \"constant_score_boolean\"\n" + + " }\n" + + " }\n" + + "}", + buildQuery(DSL.wildcard_query( + DSL.namedArgument("field", literal("field")), + DSL.namedArgument("query", literal("search query*")), + DSL.namedArgument("boost", literal("0.6")), + DSL.namedArgument("case_insensitive", literal("true")), + DSL.namedArgument("rewrite", literal("constant_score_boolean"))))); + } + @Test void query_invalid_parameter() { FunctionExpression expr = DSL.query( diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java new file mode 100644 index 0000000000..ce7a39d91a --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/filter/lucene/WildcardQueryTest.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage.script.filter.lucene; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.opensearch.sql.expression.DSL.namedArgument; + +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import 
org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprType; +import org.opensearch.sql.exception.SemanticCheckException; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.FunctionExpression; +import org.opensearch.sql.expression.env.Environment; +import org.opensearch.sql.expression.function.FunctionName; +import org.opensearch.sql.opensearch.storage.script.filter.lucene.relevance.WildcardQuery; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +class WildcardQueryTest { + private final WildcardQuery wildcardQueryQuery = new WildcardQuery(); + private static final FunctionName wildcardQueryFunc = FunctionName.of("wildcard_query"); + + static Stream> generateValidData() { + return Stream.of( + List.of( + namedArgument("field", "title"), + namedArgument("query", "query_value*"), + namedArgument("boost", "0.7"), + namedArgument("case_insensitive", "false"), + namedArgument("rewrite", "constant_score_boolean") + ) + ); + } + + @ParameterizedTest + @MethodSource("generateValidData") + public void test_valid_parameters(List validArgs) { + Assertions.assertNotNull(wildcardQueryQuery.build( + new WildcardQueryExpression(validArgs))); + } + + @Test + public void test_SyntaxCheckException_when_no_arguments() { + List arguments = List.of(); + assertThrows(SyntaxCheckException.class, + () -> wildcardQueryQuery.build(new WildcardQueryExpression(arguments))); + } + + @Test + public void test_SyntaxCheckException_when_one_argument() { + List arguments = List.of(namedArgument("field", "title")); + assertThrows(SyntaxCheckException.class, + () -> wildcardQueryQuery.build(new WildcardQueryExpression(arguments))); + } + + @Test + public void test_SemanticCheckException_when_invalid_parameter() { + List arguments = List.of( + namedArgument("field", "title"), + namedArgument("query", "query_value*"), + namedArgument("unsupported", 
"unsupported_value")); + Assertions.assertThrows(SemanticCheckException.class, + () -> wildcardQueryQuery.build(new WildcardQueryExpression(arguments))); + } + + private class WildcardQueryExpression extends FunctionExpression { + public WildcardQueryExpression(List arguments) { + super(WildcardQueryTest.this.wildcardQueryFunc, arguments); + } + + @Override + public ExprValue valueOf(Environment valueEnv) { + throw new UnsupportedOperationException("Invalid function call, " + + "valueOf function need implementation only to support Expression interface"); + } + + @Override + public ExprType type() { + throw new UnsupportedOperationException("Invalid function call, " + + "type function need implementation only to support Expression interface"); + } + } +} diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index c416c78432..a18aee8f10 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -338,6 +338,7 @@ ANALYZER: 'ANALYZER'; ANALYZE_WILDCARD: 'ANALYZE_WILDCARD'; AUTO_GENERATE_SYNONYMS_PHRASE_QUERY:'AUTO_GENERATE_SYNONYMS_PHRASE_QUERY'; BOOST: 'BOOST'; +CASE_INSENSITIVE: 'CASE_INSENSITIVE'; CUTOFF_FREQUENCY: 'CUTOFF_FREQUENCY'; DEFAULT_FIELD: 'DEFAULT_FIELD'; DEFAULT_OPERATOR: 'DEFAULT_OPERATOR'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index e6ae551fa2..6e8e0e08fe 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -473,6 +473,7 @@ singleFieldRelevanceFunctionName : MATCH | MATCHQUERY | MATCH_QUERY | MATCH_PHRASE | MATCHPHRASE | MATCHPHRASEQUERY | MATCH_BOOL_PREFIX | MATCH_PHRASE_PREFIX + | WILDCARD_QUERY | WILDCARDQUERY ; multiFieldRelevanceFunctionName @@ -502,7 +503,7 @@ highlightArg relevanceArgName : ALLOW_LEADING_WILDCARD | ANALYZER | ANALYZE_WILDCARD | AUTO_GENERATE_SYNONYMS_PHRASE_QUERY - | BOOST | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | 
ENABLE_POSITION_INCREMENTS + | BOOST | CASE_INSENSITIVE | CUTOFF_FREQUENCY | DEFAULT_FIELD | DEFAULT_OPERATOR | ENABLE_POSITION_INCREMENTS | ESCAPE | FIELDS | FLAGS | FUZZINESS | FUZZY_MAX_EXPANSIONS | FUZZY_PREFIX_LENGTH | FUZZY_REWRITE | FUZZY_TRANSPOSITIONS | LENIENT | LOW_FREQ_OPERATOR | MAX_DETERMINIZED_STATES | MAX_EXPANSIONS | MINIMUM_SHOULD_MATCH | OPERATOR | PHRASE_SLOP | PREFIX_LENGTH diff --git a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java index bf2c9af623..bfd0f93ec9 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/antlr/SQLSyntaxParserTest.java @@ -15,7 +15,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Iterator; -import java.util.List; import java.util.Map; import java.util.Random; import java.util.stream.Stream; @@ -449,6 +448,21 @@ public void can_parse_match_phrase_relevance_function() { assertNotNull(parser.parse("SELECT * FROM test WHERE match_phrase(column, 100500)")); } + @Test + public void can_parse_wildcard_query_relevance_function() { + assertNotNull( + parser.parse("SELECT * FROM test WHERE wildcard_query(column, \"this is a test*\")")); + assertNotNull( + parser.parse("SELECT * FROM test WHERE wildcard_query(column, 'this is a test*')")); + assertNotNull( + parser.parse("SELECT * FROM test WHERE wildcard_query(`column`, \"this is a test*\")")); + assertNotNull( + parser.parse("SELECT * FROM test WHERE wildcard_query(`column`, 'this is a test*')")); + assertNotNull( + parser.parse("SELECT * FROM test WHERE wildcard_query(`column`, 'this is a test*', " + + "boost=1.5, case_insensitive=true, rewrite=\"scoring_boolean\")")); + } + @ParameterizedTest @MethodSource({ "matchPhraseComplexQueries", diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java 
b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index ac68d146b2..9af4119fdf 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -637,6 +637,19 @@ public void relevanceQuery_string() { + "analyzer='keyword', time_zone='Canada/Pacific', tie_breaker='1.3')")); } + @Test + public void relevanceWildcard_query() { + assertEquals(AstDSL.function("wildcard_query", + unresolvedArg("field", stringLiteral("field")), + unresolvedArg("query", stringLiteral("search query*")), + unresolvedArg("boost", stringLiteral("1.5")), + unresolvedArg("case_insensitive", stringLiteral("true")), + unresolvedArg("rewrite", stringLiteral("scoring_boolean"))), + buildExprAst("wildcard_query(field, 'search query*', boost=1.5," + + "case_insensitive=true, rewrite='scoring_boolean'))") + ); + } + @Test public void relevanceQuery() { assertEquals(AstDSL.function("query", From c923e80cd654ee8136c74180bf0bd6231044ff71 Mon Sep 17 00:00:00 2001 From: Margarit Hakobyan Date: Fri, 9 Dec 2022 10:24:22 -0800 Subject: [PATCH 4/5] Add reverse() string function to V2 SQL Engine(#1154) * Add reverse() string function to V2 SQL engine Signed-off-by: Margarit Hakobyan --- .../org/opensearch/sql/expression/DSL.java | 4 ++++ .../function/BuiltinFunctionName.java | 1 + .../sql/expression/text/TextFunction.java | 16 +++++++++++++ .../sql/expression/text/TextFunctionTest.java | 12 ++++++++++ docs/user/dql/functions.rst | 23 +++++++++++++++++++ docs/user/ppl/functions/string.rst | 23 +++++++++++++++++++ .../opensearch/sql/ppl/TextFunctionIT.java | 5 ++++ .../opensearch/sql/sql/TextFunctionIT.java | 14 +++++++++++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 1 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 2 +- sql/src/main/antlr/OpenSearchSQLLexer.g4 | 1 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 2 +- 12 files changed, 102 insertions(+), 2 deletions(-) diff --git 
a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 3b601f98a3..fc425c6c20 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -486,6 +486,10 @@ public static FunctionExpression replace(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.REPLACE, expressions); } + public static FunctionExpression reverse(Expression... expressions) { + return compile(FunctionProperties.None, BuiltinFunctionName.REVERSE, expressions); + } + public static FunctionExpression and(Expression... expressions) { return compile(FunctionProperties.None, BuiltinFunctionName.AND, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index 0b7701d8a9..b23c7613d6 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -167,6 +167,7 @@ public enum BuiltinFunctionName { POSITION(FunctionName.of("position")), REGEXP(FunctionName.of("regexp")), REPLACE(FunctionName.of("replace")), + REVERSE(FunctionName.of("reverse")), RIGHT(FunctionName.of("right")), RTRIM(FunctionName.of("rtrim")), STRCMP(FunctionName.of("strcmp")), diff --git a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java index 5915700bf1..25eb25489c 100644 --- a/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/text/TextFunction.java @@ -49,6 +49,7 @@ public void register(BuiltinFunctionRepository repository) { repository.register(ltrim()); repository.register(position()); repository.register(replace()); + 
repository.register(reverse()); repository.register(right()); repository.register(rtrim()); repository.register(strcmp()); @@ -268,6 +269,17 @@ private DefaultFunctionResolver replace() { impl(nullMissingHandling(TextFunction::exprReplace), STRING, STRING, STRING, STRING)); } + /** + * REVERSE(str) returns reversed string of the string supplied as an argument. + * Returns NULL if the argument is NULL. + * Supports the following signature: + * (STRING) -> STRING + */ + private DefaultFunctionResolver reverse() { + return define(BuiltinFunctionName.REVERSE.getName(), + impl(nullMissingHandling(TextFunction::exprReverse), STRING, STRING)); + } + private static ExprValue exprSubstrStart(ExprValue exprValue, ExprValue start) { int startIdx = start.integerValue(); if (startIdx == 0) { @@ -331,5 +343,9 @@ private static ExprValue exprLocate(ExprValue subStr, ExprValue str, ExprValue p private static ExprValue exprReplace(ExprValue str, ExprValue from, ExprValue to) { return new ExprStringValue(str.stringValue().replaceAll(from.stringValue(), to.stringValue())); } + + private static ExprValue exprReverse(ExprValue str) { + return new ExprStringValue(new StringBuilder(str.stringValue()).reverse().toString()); + } } diff --git a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java index 5e32678b94..515b436c82 100644 --- a/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java +++ b/core/src/test/java/org/opensearch/sql/expression/text/TextFunctionTest.java @@ -412,6 +412,18 @@ void replace() { assertEquals(missingValue(), eval(DSL.replace(missingRef, DSL.literal("a"), DSL.literal("b")))); } + @Test + void reverse() { + FunctionExpression expression = DSL.reverse(DSL.literal("abcde")); + assertEquals(STRING, expression.type()); + assertEquals("edcba", eval(expression).stringValue()); + + when(nullRef.type()).thenReturn(STRING); + assertEquals(nullValue(), 
eval(DSL.reverse(nullRef))); + when(missingRef.type()).thenReturn(STRING); + assertEquals(missingValue(), eval(DSL.reverse(missingRef))); + } + void testConcatString(List strings) { String expected = null; if (strings.stream().noneMatch(Objects::isNull)) { diff --git a/docs/user/dql/functions.rst b/docs/user/dql/functions.rst index 7be50ccffb..843d6c7e45 100644 --- a/docs/user/dql/functions.rst +++ b/docs/user/dql/functions.rst @@ -2468,6 +2468,29 @@ Example:: +--------------------------------------------------+ +REVERSE +------- + +Description +>>>>>>>>>>> + +Usage: REVERSE(str) returns reversed string of the string supplied as an argument. Returns NULL if the argument is NULL. + +Argument type: STRING + +Return type: STRING + +Example:: + + os> SELECT REVERSE('abcde'), REVERSE(null) + fetched rows / total rows = 1/1 + +--------------------+-----------------+ + | REVERSE('abcde') | REVERSE(null) | + |--------------------+-----------------| + | edcba | null | + +--------------------+-----------------+ + + RIGHT ----- diff --git a/docs/user/ppl/functions/string.rst b/docs/user/ppl/functions/string.rst index b14acc88e0..361bc2ef37 100644 --- a/docs/user/ppl/functions/string.rst +++ b/docs/user/ppl/functions/string.rst @@ -175,6 +175,29 @@ Example:: +-------------------------------------+---------------------------------------+ +REVERSE +------- + +Description +>>>>>>>>>>> + +Usage: REVERSE(str) returns reversed string of the string supplied as an argument. 
+ +Argument type: STRING + +Return type: STRING + +Example:: + + os> source=people | eval `REVERSE('abcde')` = REVERSE('abcde') | fields `REVERSE('abcde')` + fetched rows / total rows = 1/1 + +--------------------+ + | REVERSE('abcde') | + |--------------------| + | edcba | + +--------------------+ + + RIGHT ----- diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java index 84717900ca..7c48bceab0 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TextFunctionIT.java @@ -139,4 +139,9 @@ public void testLocate() throws IOException { public void testReplace() throws IOException { verifyQuery("replace", "", ", 'world', ' opensearch'", "hello", " opensearch", "hello opensearch"); } + + @Test + public void testReverse() throws IOException { + verifyQuery("reverse", "", "", "olleh", "dlrow", "dlrowolleh"); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java index c907b36a63..175cafd31e 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/TextFunctionIT.java @@ -41,12 +41,26 @@ void verifyQuery(String query, String type, Integer output) throws IOException { verifyDataRows(result, rows(output)); } + void verifyQueryWithNullOutput(String query, String type) throws IOException { + JSONObject result = executeQuery("select 'test null'," + query); + verifySchema(result, schema(query, null, type), + schema("'test null'", null, type)); + verifyDataRows(result, rows("test null", null)); + } + @Test public void testRegexp() throws IOException { verifyQuery("'a' regexp 'b'", "integer", 0); verifyQuery("'a' regexp '.*'", "integer", 1); } + @Test + public void testReverse() throws IOException { + 
verifyQuery("reverse('hello')", "keyword", "olleh"); + verifyQuery("reverse('')", "keyword", ""); + verifyQueryWithNullOutput("reverse(null)", "keyword"); + } + @Test public void testSubstr() throws IOException { verifyQuery("substr('hello', 2)", "keyword", "ello"); diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 8c0340e7f1..9282c42308 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -290,6 +290,7 @@ LEFT: 'LEFT'; ASCII: 'ASCII'; LOCATE: 'LOCATE'; REPLACE: 'REPLACE'; +REVERSE: 'REVERSE'; CAST: 'CAST'; // BOOL FUNCTIONS diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 6dba1ae783..a0d6553875 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -487,7 +487,7 @@ systemFunctionBase textFunctionBase : SUBSTR | SUBSTRING | TRIM | LTRIM | RTRIM | LOWER | UPPER | CONCAT | CONCAT_WS | LENGTH | STRCMP - | RIGHT | LEFT | ASCII | LOCATE | REPLACE + | RIGHT | LEFT | ASCII | LOCATE | REPLACE | REVERSE ; positionFunctionName diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index a18aee8f10..a359f48be3 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -243,6 +243,7 @@ REPLACE: 'REPLACE'; RINT: 'RINT'; ROUND: 'ROUND'; RTRIM: 'RTRIM'; +REVERSE: 'REVERSE'; SIGN: 'SIGN'; SIGNUM: 'SIGNUM'; SIN: 'SIN'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index 6e8e0e08fe..58d4be1813 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -454,7 +454,7 @@ dateTimeFunctionName textFunctionName : SUBSTR | SUBSTRING | TRIM | LTRIM | RTRIM | LOWER | UPPER | CONCAT | CONCAT_WS | SUBSTR | LENGTH | STRCMP | RIGHT | LEFT - | ASCII | LOCATE | REPLACE + | ASCII | LOCATE | REPLACE | REVERSE 
; flowControlFunctionName From c5e8fc09b40e8440389e6c91d67144d48560eba2 Mon Sep 17 00:00:00 2001 From: Andriy Redko Date: Fri, 9 Dec 2022 18:13:19 -0500 Subject: [PATCH 5/5] Update Jackson to 2.14.1 and fix dependency resolution issues (#1150) Signed-off-by: Andriy Redko Signed-off-by: Andriy Redko --- build.gradle | 4 ++-- integ-test/build.gradle | 2 ++ plugin/build.gradle | 2 ++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/build.gradle b/build.gradle index a9a171da01..b6f64560db 100644 --- a/build.gradle +++ b/build.gradle @@ -8,8 +8,8 @@ buildscript { ext { opensearch_version = System.getProperty("opensearch.version", "2.4.0-SNAPSHOT") spring_version = "5.3.22" - jackson_version = "2.14.0" - jackson_databind_version = "2.14.0" + jackson_version = "2.14.1" + jackson_databind_version = "2.14.1" isSnapshot = "true" == System.getProperty("build.snapshot", "true") buildVersionQualifier = System.getProperty("build.version_qualifier", "") version_tokens = opensearch_version.tokenize('-') diff --git a/integ-test/build.gradle b/integ-test/build.gradle index 6c10d55262..8adfd84b40 100644 --- a/integ-test/build.gradle +++ b/integ-test/build.gradle @@ -66,6 +66,8 @@ configurations.all { resolutionStrategy.force 'commons-codec:commons-codec:1.13' resolutionStrategy.force 'com.google.guava:guava:31.0.1-jre' resolutionStrategy.force "com.fasterxml.jackson.core:jackson-core:${jackson_version}" + resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${jackson_version}" + resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.core:jackson-databind:${jackson_databind_version}" resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.6.0" diff --git a/plugin/build.gradle b/plugin/build.gradle index d2bdb87275..d689da0f57 100644 --- 
a/plugin/build.gradle +++ b/plugin/build.gradle @@ -89,6 +89,8 @@ configurations.all { // enforce 1.1.3, https://www.whitesourcesoftware.com/vulnerability-database/WS-2019-0379 resolutionStrategy.force 'commons-codec:commons-codec:1.13' resolutionStrategy.force 'com.google.guava:guava:31.0.1-jre' + resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-yaml:${jackson_version}" + resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-smile:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${jackson_version}" resolutionStrategy.force "com.fasterxml.jackson.core:jackson-databind:${jackson_databind_version}" resolutionStrategy.force "org.jetbrains.kotlin:kotlin-stdlib:1.6.0"