From 3e2cb1dd907d30ac7dc3e1058078adf1295bb6fe Mon Sep 17 00:00:00 2001 From: Peng Huo Date: Mon, 9 Dec 2024 14:37:01 -0800 Subject: [PATCH 1/2] Fix FilterOperator to cache next element and avoid repeated consumption on hasNext() calls (#3123) Signed-off-by: Peng Huo --- .../sql/planner/physical/FilterOperator.java | 32 +++++-- .../planner/physical/FilterOperatorTest.java | 84 +++++++++++++++++++ .../resources/correctness/bugfixes/3121.txt | 1 + 3 files changed, 109 insertions(+), 8 deletions(-) create mode 100644 integ-test/src/test/resources/correctness/bugfixes/3121.txt diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java index 192ea5cb4f..088dd07f8d 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/FilterOperator.java @@ -28,6 +28,7 @@ public class FilterOperator extends PhysicalPlan { @Getter private final PhysicalPlan input; @Getter private final Expression conditions; @ToString.Exclude private ExprValue next = null; + @ToString.Exclude private boolean nextPrepared = false; @Override public R accept(PhysicalPlanNodeVisitor visitor, C context) { @@ -41,19 +42,34 @@ public List getChild() { @Override public boolean hasNext() { + if (!nextPrepared) { + prepareNext(); + } + return next != null; + } + + @Override + public ExprValue next() { + if (!nextPrepared) { + prepareNext(); + } + ExprValue result = next; + next = null; + nextPrepared = false; + return result; + } + + private void prepareNext() { while (input.hasNext()) { ExprValue inputValue = input.next(); ExprValue exprValue = conditions.valueOf(inputValue.bindingTuples()); - if (!(exprValue.isNull() || exprValue.isMissing()) && (exprValue.booleanValue())) { + if (!(exprValue.isNull() || exprValue.isMissing()) && exprValue.booleanValue()) { next = inputValue; - return true; + nextPrepared = true; + return; } } - return false; - } - - @Override - public ExprValue next() { - return next; + next = null; + nextPrepared = true; } } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java index bfe3b323c4..ba2354b168 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/FilterOperatorTest.java @@ -8,14 +8,24 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_FALSE; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_MISSING; import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_NULL; +import static org.opensearch.sql.data.model.ExprValueUtils.LITERAL_TRUE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; +import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.filter; import com.google.common.collect.ImmutableMap; import java.util.LinkedHashMap; import java.util.List; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; @@ -26,12 +36,22 @@ import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.data.model.ExprValueUtils; import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) class FilterOperatorTest extends PhysicalPlanTestBase { @Mock private PhysicalPlan inputPlan; + @Mock private Expression condition; + + private FilterOperator filterOperator; + + @BeforeEach + public void setup() { + filterOperator = filter(inputPlan, condition); + } + @Test public void filter_test() { FilterOperator plan = @@ -82,4 +102,68 @@ public void missing_value_should_been_ignored() { List result = execute(plan); assertEquals(0, result.size()); } + + @Test + public void testHasNextWhenInputHasNoElements() { + when(inputPlan.hasNext()).thenReturn(false); + + assertFalse( + filterOperator.hasNext(), "hasNext() should return false when input has no elements"); + } + + @Test + public void testHasNextWithMatchingCondition() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true).thenReturn(false); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_TRUE); + + assertTrue(filterOperator.hasNext(), "hasNext() should return true when condition matches"); + assertEquals( + inputValue, filterOperator.next(), "next() should return the matching input value"); + } + + @Test + public void testHasNextWithNonMatchingCondition() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true, false); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_FALSE); + + assertFalse( + filterOperator.hasNext(), "hasNext() should return false if no values match the condition"); + } + + @Test + public void testMultipleCallsToHasNextDoNotConsumeInput() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_TRUE); + + assertTrue( + filterOperator.hasNext(), + "First hasNext() call should return true if there is a matching value"); + verify(inputPlan, times(1)).next(); + assertTrue( + filterOperator.hasNext(), + "Subsequent hasNext() calls should still return true without advancing the input"); + verify(inputPlan, times(1)).next(); + assertEquals( + inputValue, filterOperator.next(), "next() should return the matching input value"); + verify(inputPlan, times(1)).next(); + } + + @Test + public void testNextWithoutCallingHasNext() { + ExprValue inputValue = mock(ExprValue.class); + when(inputPlan.hasNext()).thenReturn(true, false); + when(inputPlan.next()).thenReturn(inputValue); + when(condition.valueOf(any())).thenReturn(LITERAL_TRUE); + + assertEquals( + inputValue, + filterOperator.next(), + "next() should return the matching input value even if hasNext() was not called"); + } } diff --git a/integ-test/src/test/resources/correctness/bugfixes/3121.txt b/integ-test/src/test/resources/correctness/bugfixes/3121.txt new file mode 100644 index 0000000000..f60f724897 --- /dev/null +++ b/integ-test/src/test/resources/correctness/bugfixes/3121.txt @@ -0,0 +1 @@ +SELECT Origin, Dest FROM (SELECT * FROM opensearch_dashboards_sample_data_flights WHERE AvgTicketPrice > 100 GROUP BY Origin, Dest, AvgTicketPrice) AS flights WHERE AvgTicketPrice < 1000 ORDER BY AvgTicketPrice LIMIT 30 From ed0ca8ddf6a217ebafd24dbb3263db798bf15e27 Mon Sep 17 00:00:00 2001 From: James Duong Date: Thu, 12 Dec 2024 11:01:17 -0800 Subject: [PATCH 2/2] Add trendline PPL command (#3071) * Add trendline (With SWA) PPL command --------- Signed-off-by: James Duong Signed-off-by: Andrew Carbonetto Co-authored-by: Andrew Carbonetto --- .../org/opensearch/sql/analysis/Analyzer.java | 94 ++++- .../sql/ast/AbstractNodeVisitor.java | 9 + .../org/opensearch/sql/ast/dsl/AstDSL.java | 14 + .../opensearch/sql/ast/tree/Trendline.java | 71 ++++ .../org/opensearch/sql/executor/Explain.java | 32 ++ .../sql/planner/DefaultImplementor.java | 7 + .../sql/planner/logical/LogicalPlanDSL.java | 7 + .../logical/LogicalPlanNodeVisitor.java | 4 + .../sql/planner/logical/LogicalTrendline.java | 42 ++ .../physical/PhysicalPlanNodeVisitor.java | 4 + .../planner/physical/TrendlineOperator.java | 317 ++++++++++++++ .../opensearch/sql/analysis/AnalyzerTest.java | 60 +++ .../opensearch/sql/executor/ExplainTest.java | 43 ++ .../sql/planner/DefaultImplementorTest.java | 21 + .../logical/LogicalPlanNodeVisitorTest.java | 14 +- .../physical/PhysicalPlanNodeVisitorTest.java | 31 +- .../physical/TrendlineOperatorTest.java | 398 ++++++++++++++++++ docs/category.json | 1 + docs/user/ppl/cmd/trendline.rst | 90 ++++ docs/user/ppl/index.rst | 2 + .../org/opensearch/sql/ppl/ExplainIT.java | 26 ++ .../sql/ppl/TrendlineCommandIT.java | 78 ++++ .../ppl/explain_trendline_push.json | 32 ++ .../ppl/explain_trendline_sort_push.json | 32 ++ .../OpenSearchExecutionProtector.java | 7 + .../OpenSearchExecutionProtectorTest.java | 21 + ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 4 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 14 + .../opensearch/sql/ppl/parser/AstBuilder.java | 16 + .../sql/ppl/parser/AstExpressionBuilder.java | 25 ++ .../sql/ppl/utils/PPLQueryDataAnonymizer.java | 27 +- .../sql/ppl/parser/AstBuilderTest.java | 74 ++++ .../ppl/utils/PPLQueryDataAnonymizerTest.java | 7 + 33 files changed, 1601 insertions(+), 23 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java create mode 100644 core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java create mode 100644 core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java create mode 100644 docs/user/ppl/cmd/trendline.rst create mode 100644 integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java create mode 100644 integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json create mode 100644 integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json diff --git a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java index 71db736f78..d0051568c4 100644 --- a/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java +++ b/core/src/main/java/org/opensearch/sql/analysis/Analyzer.java @@ -10,7 +10,10 @@ import static org.opensearch.sql.ast.tree.Sort.NullOrder.NULL_LAST; import static org.opensearch.sql.ast.tree.Sort.SortOrder.ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder.DESC; +import static org.opensearch.sql.data.type.ExprCoreType.DATE; import static org.opensearch.sql.data.type.ExprCoreType.STRUCT; +import static org.opensearch.sql.data.type.ExprCoreType.TIME; +import static org.opensearch.sql.data.type.ExprCoreType.TIMESTAMP; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALOUS; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_ANOMALY_GRADE; import static org.opensearch.sql.utils.MLCommonsConstants.RCF_SCORE; @@ -22,6 +25,7 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.Optional; @@ -62,6 +66,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; import org.opensearch.sql.common.antlr.SyntaxCheckException; @@ -100,6 +105,7 @@ import org.opensearch.sql.planner.logical.LogicalRemove; import org.opensearch.sql.planner.logical.LogicalRename; import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.planner.logical.LogicalTrendline; import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.physical.datasource.DataSourceTable; import org.opensearch.sql.storage.Table; @@ -469,23 +475,7 @@ public LogicalPlan visitParse(Parse node, AnalysisContext context) { @Override public LogicalPlan visitSort(Sort node, AnalysisContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); - ExpressionReferenceOptimizer optimizer = - new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); - - List> sortList = - node.getSortList().stream() - .map( - sortField -> { - var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); - if (analyzed == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", sortField.getField())); - } - Expression expression = optimizer.optimize(analyzed, context); - return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); - }) - .collect(Collectors.toList()); - return new LogicalSort(child, sortList); + return buildSort(child, context, node.getSortList()); } /** Build {@link LogicalDedupe}. */ @@ -594,6 +584,55 @@ public LogicalPlan visitML(ML node, AnalysisContext context) { return new LogicalML(child, node.getArguments()); } + /** Build {@link LogicalTrendline} for Trendline command. */ + @Override + public LogicalPlan visitTrendline(Trendline node, AnalysisContext context) { + final LogicalPlan child = node.getChild().get(0).accept(this, context); + + final TypeEnvironment currEnv = context.peek(); + final List computations = node.getComputations(); + final ImmutableList.Builder> + computationsAndTypes = ImmutableList.builder(); + computations.forEach( + computation -> { + final Expression resolvedField = + expressionAnalyzer.analyze(computation.getDataField(), context); + final ExprCoreType averageType; + // Duplicate the semantics of AvgAggregator#create(): + // - All numerical types have the DOUBLE type for the moving average. + // - All datetime types have the same datetime type for the moving average. + if (ExprCoreType.numberTypes().contains(resolvedField.type())) { + averageType = ExprCoreType.DOUBLE; + } else { + switch (resolvedField.type()) { + case DATE: + case TIME: + case TIMESTAMP: + averageType = (ExprCoreType) resolvedField.type(); + break; + default: + throw new SemanticCheckException( + String.format( + "Invalid field used for trendline computation %s. Source field %s had type" + + " %s but must be a numerical or datetime field.", + computation.getAlias(), + computation.getDataField().getChild().get(0), + resolvedField.type().typeName())); + } + } + currEnv.define(new Symbol(Namespace.FIELD_NAME, computation.getAlias()), averageType); + computationsAndTypes.add(Pair.of(computation, averageType)); + }); + + if (node.getSortByField().isEmpty()) { + return new LogicalTrendline(child, computationsAndTypes.build()); + } + + return new LogicalTrendline( + buildSort(child, context, Collections.singletonList(node.getSortByField().get())), + computationsAndTypes.build()); + } + @Override public LogicalPlan visitPaginate(Paginate paginate, AnalysisContext context) { LogicalPlan child = paginate.getChild().get(0).accept(this, context); @@ -612,6 +651,27 @@ public LogicalPlan visitCloseCursor(CloseCursor closeCursor, AnalysisContext con return new LogicalCloseCursor(closeCursor.getChild().get(0).accept(this, context)); } + private LogicalSort buildSort( + LogicalPlan child, AnalysisContext context, List sortFields) { + ExpressionReferenceOptimizer optimizer = + new ExpressionReferenceOptimizer(expressionAnalyzer.getRepository(), child); + + List> sortList = + sortFields.stream() + .map( + sortField -> { + var analyzed = expressionAnalyzer.analyze(sortField.getField(), context); + if (analyzed == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", sortField.getField())); + } + Expression expression = optimizer.optimize(analyzed, context); + return ImmutablePair.of(analyzeSortOption(sortField.getFieldArgs()), expression); + }) + .collect(Collectors.toList()); + return new LogicalSort(child, sortList); + } + /** * The first argument is always "asc", others are optional. Given nullFirst argument, use its * value. Otherwise just use DEFAULT_ASC/DESC. diff --git a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index a0520dc70e..f27260dd5f 100644 --- a/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -60,6 +60,7 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.Values; /** AST nodes visitor Defines the traverse path. */ @@ -110,6 +111,14 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitTrendline(Trendline node, C context) { + return visitChildren(node, context); + } + + public T visitTrendlineComputation(Trendline.TrendlineComputation node, C context) { + return visitChildren(node, context); + } + public T visitProject(Project node, C context) { return visitChildren(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java index 8135731ff6..d9956609ec 100644 --- a/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java +++ b/core/src/main/java/org/opensearch/sql/ast/dsl/AstDSL.java @@ -8,6 +8,7 @@ import com.google.common.collect.ImmutableList; import java.util.Arrays; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; import lombok.experimental.UtilityClass; import org.apache.commons.lang3.tuple.ImmutablePair; @@ -62,6 +63,7 @@ import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.ast.tree.Values; @@ -466,6 +468,18 @@ public static Limit limit(UnresolvedPlan input, Integer limit, Integer offset) { return new Limit(limit, offset).attach(input); } + public static Trendline trendline( + UnresolvedPlan input, + Optional sortField, + Trendline.TrendlineComputation... computations) { + return new Trendline(sortField, Arrays.asList(computations)).attach(input); + } + + public static Trendline.TrendlineComputation computation( + Integer numDataPoints, Field dataField, String alias, Trendline.TrendlineType type) { + return new Trendline.TrendlineComputation(numDataPoints, dataField, alias, type); + } + public static Parse parse( UnresolvedPlan input, ParseMethod parseMethod, diff --git a/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java new file mode 100644 index 0000000000..aa4fcc200d --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/ast/tree/Trendline.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.Optional; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import lombok.ToString; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +@ToString +@Getter +@RequiredArgsConstructor +@EqualsAndHashCode(callSuper = false) +public class Trendline extends UnresolvedPlan { + + private UnresolvedPlan child; + private final Optional sortByField; + private final List computations; + + @Override + public Trendline attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return ImmutableList.of(child); + } + + @Override + public T accept(AbstractNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Getter + public static class TrendlineComputation extends UnresolvedExpression { + + private final Integer numberOfDataPoints; + private final Field dataField; + private final String alias; + private final TrendlineType computationType; + + public TrendlineComputation( + Integer numberOfDataPoints, Field dataField, String alias, TrendlineType computationType) { + this.numberOfDataPoints = numberOfDataPoints; + this.dataField = dataField; + this.alias = alias; + this.computationType = computationType; + } + + @Override + public R accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitTrendlineComputation(this, context); + } + } + + public enum TrendlineType { + SMA + } +} diff --git a/core/src/main/java/org/opensearch/sql/executor/Explain.java b/core/src/main/java/org/opensearch/sql/executor/Explain.java index fffbe6f693..31890a8090 100644 --- a/core/src/main/java/org/opensearch/sql/executor/Explain.java +++ b/core/src/main/java/org/opensearch/sql/executor/Explain.java @@ -8,12 +8,14 @@ import com.google.common.collect.ImmutableMap; import java.util.ArrayList; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.function.Consumer; import java.util.function.Function; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponseNode; import org.opensearch.sql.expression.Expression; @@ -31,6 +33,7 @@ import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.TableScanOperator; @@ -211,6 +214,21 @@ public ExplainResponseNode visitNested(NestedOperator node, Object context) { explanNode -> explanNode.setDescription(ImmutableMap.of("nested", node.getFields()))); } + @Override + public ExplainResponseNode visitTrendline(TrendlineOperator node, Object context) { + return explain( + node, + context, + explainNode -> + explainNode.setDescription( + ImmutableMap.of( + "computations", + describeTrendlineComputations( + node.getComputations().stream() + .map(Pair::getKey) + .collect(Collectors.toList()))))); + } + protected ExplainResponseNode explain( PhysicalPlan node, Object context, Consumer doExplain) { ExplainResponseNode explainNode = new ExplainResponseNode(getOperatorName(node)); @@ -245,4 +263,18 @@ private Map> describeSortList( "sortOrder", p.getLeft().getSortOrder().toString(), "nullOrder", p.getLeft().getNullOrder().toString()))); } + + private List> describeTrendlineComputations( + List computations) { + return computations.stream() + .map( + computation -> + ImmutableMap.of( + "computationType", + computation.getComputationType().name().toLowerCase(Locale.ROOT), + "numberOfDataPoints", computation.getNumberOfDataPoints().toString(), + "dataField", computation.getDataField().getChild().get(0).toString(), + "alias", computation.getAlias())) + .collect(Collectors.toList()); + } } diff --git a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java index f962c3e4bf..c988084d1b 100644 --- a/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java +++ b/core/src/main/java/org/opensearch/sql/planner/DefaultImplementor.java @@ -23,6 +23,7 @@ import org.opensearch.sql.planner.logical.LogicalRemove; import org.opensearch.sql.planner.logical.LogicalRename; import org.opensearch.sql.planner.logical.LogicalSort; +import org.opensearch.sql.planner.logical.LogicalTrendline; import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.logical.LogicalWindow; import org.opensearch.sql.planner.physical.AggregationOperator; @@ -39,6 +40,7 @@ import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.read.TableScanBuilder; @@ -166,6 +168,11 @@ public PhysicalPlan visitCloseCursor(LogicalCloseCursor node, C context) { return new CursorCloseOperator(visitChild(node, context)); } + @Override + public PhysicalPlan visitTrendline(LogicalTrendline plan, C context) { + return new TrendlineOperator(visitChild(plan, context), plan.getComputations()); + } + // Called when paging query requested without `FROM` clause only @Override public PhysicalPlan visitPaginate(LogicalPaginate plan, C context) { diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java index 2a886ba0ca..13c6d7a979 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanDSL.java @@ -15,6 +15,8 @@ import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.NamedExpression; @@ -130,6 +132,11 @@ public static LogicalPlan rareTopN( return new LogicalRareTopN(input, commandType, noOfResults, Arrays.asList(fields), groupByList); } + public static LogicalTrendline trendline( + LogicalPlan input, Pair... computations) { + return new LogicalTrendline(input, Arrays.asList(computations)); + } + @SafeVarargs public LogicalPlan values(List... values) { return new LogicalValues(Arrays.asList(values)); diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java index 156db35306..c9eedd8efc 100644 --- a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitor.java @@ -104,6 +104,10 @@ public R visitAD(LogicalAD plan, C context) { return visitNode(plan, context); } + public R visitTrendline(LogicalTrendline plan, C context) { + return visitNode(plan, context); + } + public R visitPaginate(LogicalPaginate plan, C context) { return visitNode(plan, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java new file mode 100644 index 0000000000..3e992035e2 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/logical/LogicalTrendline.java @@ -0,0 +1,42 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.logical; + +import java.util.Collections; +import java.util.List; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.type.ExprCoreType; + +/* + * Trendline logical plan. + */ +@Getter +@ToString +@EqualsAndHashCode(callSuper = true) +public class LogicalTrendline extends LogicalPlan { + private final List> computations; + + /** + * Constructor of LogicalTrendline. + * + * @param child child logical plan + * @param computations the computations for this trendline call. + */ + public LogicalTrendline( + LogicalPlan child, List> computations) { + super(Collections.singletonList(child)); + this.computations = computations; + } + + @Override + public R accept(LogicalPlanNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } +} diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java index 67d7a05135..66c7219e39 100644 --- a/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java +++ b/core/src/main/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitor.java @@ -96,6 +96,10 @@ public R visitML(PhysicalPlan node, C context) { return visitNode(node, context); } + public R visitTrendline(TrendlineOperator node, C context) { + return visitNode(node, context); + } + public R visitCursorClose(CursorCloseOperator node, C context) { return visitNode(node, context); } diff --git a/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java new file mode 100644 index 0000000000..7bf10964cf --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/planner/physical/TrendlineOperator.java @@ -0,0 +1,317 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static java.time.temporal.ChronoUnit.MILLIS; + +import com.google.common.collect.EvictingQueue; +import com.google.common.collect.ImmutableMap.Builder; +import java.time.Instant; +import java.time.LocalTime; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.ToString; +import org.apache.commons.lang3.tuple.Pair; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.data.model.ExprTupleValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.LiteralExpression; + +/** Trendline command implementation */ +@ToString +@EqualsAndHashCode(callSuper = false) +public class TrendlineOperator extends PhysicalPlan { + @Getter private final PhysicalPlan input; + @Getter private final List> computations; + @EqualsAndHashCode.Exclude private final List accumulators; + @EqualsAndHashCode.Exclude private final Map fieldToIndexMap; + @EqualsAndHashCode.Exclude private final HashSet aliases; + + public TrendlineOperator( + PhysicalPlan input, List> computations) { + this.input = input; + this.computations = computations; + this.accumulators = computations.stream().map(TrendlineOperator::createAccumulator).toList(); + fieldToIndexMap = new HashMap<>(computations.size()); + aliases = new HashSet<>(computations.size()); + for (int i = 0; i < computations.size(); ++i) { + final Trendline.TrendlineComputation computation = computations.get(i).getKey(); + fieldToIndexMap.put(computation.getDataField().getChild().get(0).toString(), i); + aliases.add(computation.getAlias()); + } + } + + @Override + public R accept(PhysicalPlanNodeVisitor visitor, C context) { + return visitor.visitTrendline(this, context); + } + + @Override + public List getChild() { + return Collections.singletonList(input); + } + + @Override + public boolean hasNext() { + return getChild().getFirst().hasNext(); + } + + @Override + public ExprValue next() { + final ExprValue result; + final ExprValue next = input.next(); + final Map inputStruct = consumeInputTuple(next); + final Builder mapBuilder = new Builder<>(); + mapBuilder.putAll(inputStruct); + + // Add calculated trendline values, which might overwrite existing fields from the input. + for (int i = 0; i < accumulators.size(); ++i) { + final ExprValue calculateResult = accumulators.get(i).calculate(); + final String field = computations.get(i).getKey().getAlias(); + if (calculateResult != null) { + mapBuilder.put(field, calculateResult); + } + } + + result = ExprTupleValue.fromExprValueMap(mapBuilder.buildKeepingLast()); + return result; + } + + private Map consumeInputTuple(ExprValue inputValue) { + final Map tupleValue = ExprValueUtils.getTupleValue(inputValue); + for (String bindName : tupleValue.keySet()) { + final Integer index = fieldToIndexMap.get(bindName); + if (index != null) { + final ExprValue fieldValue = tupleValue.get(bindName); + if (!fieldValue.isNull()) { + accumulators.get(index).accumulate(fieldValue); + } + } + } + tupleValue.keySet().removeAll(aliases); + return tupleValue; + } + + private static TrendlineAccumulator createAccumulator( + Pair computation) { + // Add a switch statement based on computation type to choose the accumulator when more + // types of computations are supported. + return new SimpleMovingAverageAccumulator(computation.getKey(), computation.getValue()); + } + + /** Maintains stateful information for calculating the trendline. */ + private interface TrendlineAccumulator { + void accumulate(ExprValue value); + + ExprValue calculate(); + + static ArithmeticEvaluator getEvaluator(ExprCoreType type) { + switch (type) { + case DOUBLE: + return NumericArithmeticEvaluator.INSTANCE; + case DATE: + return DateArithmeticEvaluator.INSTANCE; + case TIME: + return TimeArithmeticEvaluator.INSTANCE; + case TIMESTAMP: + return TimestampArithmeticEvaluator.INSTANCE; + } + throw new IllegalArgumentException( + String.format("Invalid type %s used for moving average.", type.typeName())); + } + } + + private static class SimpleMovingAverageAccumulator implements TrendlineAccumulator { + private final LiteralExpression dataPointsNeeded; + private final EvictingQueue receivedValues; + private final ArithmeticEvaluator evaluator; + private Expression runningTotal = null; + + public SimpleMovingAverageAccumulator( + Trendline.TrendlineComputation computation, ExprCoreType type) { + dataPointsNeeded = DSL.literal(computation.getNumberOfDataPoints().doubleValue()); + receivedValues = EvictingQueue.create(computation.getNumberOfDataPoints()); + evaluator = TrendlineAccumulator.getEvaluator(type); + } + + @Override + public void accumulate(ExprValue value) { + if (dataPointsNeeded.valueOf().integerValue() == 1) { + runningTotal = evaluator.calculateFirstTotal(Collections.singletonList(value)); + receivedValues.add(value); + return; + } + + final ExprValue valueToRemove; + if (receivedValues.size() == dataPointsNeeded.valueOf().integerValue()) { + valueToRemove = receivedValues.remove(); + } else { + valueToRemove = null; + } + receivedValues.add(value); + + if (receivedValues.size() == dataPointsNeeded.valueOf().integerValue()) { + if (runningTotal != null) { + // We can use the previous calculation. + // Subtract the evicted value and add the new value. + // Refactored, that would be previous + (newValue - oldValue). + runningTotal = evaluator.add(runningTotal, value, valueToRemove); + } else { + // This is the first average calculation so sum the entire receivedValues dataset. + final List data = receivedValues.stream().toList(); + runningTotal = evaluator.calculateFirstTotal(data); + } + } + } + + @Override + public ExprValue calculate() { + if (receivedValues.size() < dataPointsNeeded.valueOf().integerValue()) { + return null; + } else if (dataPointsNeeded.valueOf().integerValue() == 1) { + return receivedValues.peek(); + } + return evaluator.evaluate(runningTotal, dataPointsNeeded); + } + } + + private interface ArithmeticEvaluator { + Expression calculateFirstTotal(List dataPoints); + + Expression add(Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue); + + ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints); + } + + private static class NumericArithmeticEvaluator implements ArithmeticEvaluator { + private static final NumericArithmeticEvaluator INSTANCE = new NumericArithmeticEvaluator(); + + private NumericArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0.0D); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.doubleValue())); + } + return DSL.literal(total.valueOf().doubleValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add(runningTotal, DSL.subtract(DSL.literal(incomingValue), DSL.literal(evictedValue))) + .valueOf() + .doubleValue()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return DSL.divide(runningTotal, numberOfDataPoints).valueOf(); + } + } + + private static class DateArithmeticEvaluator implements ArithmeticEvaluator { + private static final DateArithmeticEvaluator INSTANCE = new DateArithmeticEvaluator(); + + private DateArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + return TimestampArithmeticEvaluator.INSTANCE.calculateFirstTotal(dataPoints); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return TimestampArithmeticEvaluator.INSTANCE.add(runningTotal, incomingValue, evictedValue); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + final ExprValue timestampResult = + TimestampArithmeticEvaluator.INSTANCE.evaluate(runningTotal, numberOfDataPoints); + return ExprValueUtils.dateValue(timestampResult.dateValue()); + } + } + + private static class TimeArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimeArithmeticEvaluator INSTANCE = new TimeArithmeticEvaluator(); + + private TimeArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(MILLIS.between(LocalTime.MIN, dataPoint.timeValue()))); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(MILLIS.between(LocalTime.MIN, incomingValue.timeValue())), + DSL.literal(MILLIS.between(LocalTime.MIN, evictedValue.timeValue())))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timeValue( + LocalTime.MIN.plus( + DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue(), MILLIS)); + } + } + + private static class TimestampArithmeticEvaluator implements ArithmeticEvaluator { + private static final TimestampArithmeticEvaluator INSTANCE = new TimestampArithmeticEvaluator(); + + private TimestampArithmeticEvaluator() {} + + @Override + public Expression calculateFirstTotal(List dataPoints) { + Expression total = DSL.literal(0); + for (ExprValue dataPoint : dataPoints) { + total = DSL.add(total, DSL.literal(dataPoint.timestampValue().toEpochMilli())); + } + return DSL.literal(total.valueOf().longValue()); + } + + @Override + public Expression add( + Expression runningTotal, ExprValue incomingValue, ExprValue evictedValue) { + return DSL.literal( + DSL.add( + runningTotal, + DSL.subtract( + DSL.literal(incomingValue.timestampValue().toEpochMilli()), + DSL.literal(evictedValue.timestampValue().toEpochMilli()))) + .valueOf()); + } + + @Override + public ExprValue evaluate(Expression runningTotal, LiteralExpression numberOfDataPoints) { + return ExprValueUtils.timestampValue( + Instant.ofEpochMilli(DSL.divide(runningTotal, numberOfDataPoints).valueOf().longValue())); + } + } +} diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 4f06ce9d23..d6cb0544d8 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -18,6 +18,7 @@ import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.compare; +import static org.opensearch.sql.ast.dsl.AstDSL.computation; import static org.opensearch.sql.ast.dsl.AstDSL.field; import static org.opensearch.sql.ast.dsl.AstDSL.filter; import static org.opensearch.sql.ast.dsl.AstDSL.filteredAggregate; @@ -33,6 +34,7 @@ import static org.opensearch.sql.ast.tree.Sort.SortOption; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; import static org.opensearch.sql.ast.tree.Sort.SortOrder; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.stringValue; import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN; @@ -66,6 +68,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.junit.jupiter.api.Disabled; @@ -89,6 +92,7 @@ import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.antlr.SyntaxCheckException; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.exception.SemanticCheckException; import org.opensearch.sql.expression.DSL; @@ -1481,6 +1485,62 @@ public void fillnull_various_values() { AstDSL.field("int_null_value"), AstDSL.intLiteral(1)))))); } + @Test + public void trendline() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.relation("schema", table), + Pair.of(computation(5, field("float_value"), "test_field_alias", SMA), DOUBLE), + Pair.of(computation(1, field("double_value"), "test_field_alias_2", SMA), DOUBLE)), + AstDSL.trendline( + AstDSL.relation("schema"), + Optional.empty(), + computation(5, field("float_value"), "test_field_alias", SMA), + computation(1, field("double_value"), "test_field_alias_2", SMA))); + } + + @Test + public void trendline_datetime_types() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.relation("schema", table), + Pair.of(computation(5, field("timestamp_value"), "test_field_alias", SMA), TIMESTAMP)), + AstDSL.trendline( + AstDSL.relation("schema"), + Optional.empty(), + computation(5, field("timestamp_value"), "test_field_alias", SMA))); + } + + @Test + public void trendline_illegal_type() { + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.trendline( + AstDSL.relation("schema"), + Optional.empty(), + computation(5, field("array_value"), "test_field_alias", SMA)))); + } + + @Test + public void trendline_with_sort() { + assertAnalyzeEqual( + LogicalPlanDSL.trendline( + LogicalPlanDSL.sort( + LogicalPlanDSL.relation("schema", table), + Pair.of( + new SortOption(SortOrder.ASC, NullOrder.NULL_FIRST), + DSL.ref("float_value", ExprCoreType.FLOAT))), + Pair.of(computation(5, field("float_value"), "test_field_alias", SMA), DOUBLE), + Pair.of(computation(1, field("double_value"), "test_field_alias_2", SMA), DOUBLE)), + AstDSL.trendline( + AstDSL.relation("schema"), + Optional.of(field("float_value", argument("asc", booleanLiteral(true)))), + computation(5, field("float_value"), "test_field_alias", SMA), + computation(1, field("double_value"), "test_field_alias_2", SMA))); + } + @Test public void ad_batchRCF_relation() { Map argumentMap = diff --git a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java index eaeae07242..febf662843 100644 --- a/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java +++ b/core/src/test/java/org/opensearch/sql/executor/ExplainTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.opensearch.sql.ast.tree.RareTopN.CommandType.TOP; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; @@ -31,6 +32,8 @@ import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.values; import static org.opensearch.sql.planner.physical.PhysicalPlanDSL.window; +import com.google.common.collect.ImmutableMap; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.Set; @@ -39,6 +42,7 @@ import org.junit.jupiter.api.DisplayNameGeneration; import org.junit.jupiter.api.DisplayNameGenerator; import org.junit.jupiter.api.Test; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.data.model.ExprValue; import org.opensearch.sql.executor.ExecutionEngine.ExplainResponse; @@ -52,6 +56,7 @@ import org.opensearch.sql.expression.aggregation.NamedAggregator; import org.opensearch.sql.expression.window.WindowDefinition; import org.opensearch.sql.planner.physical.PhysicalPlan; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.storage.TableScanOperator; @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -256,6 +261,44 @@ void can_explain_nested() { explain.apply(plan)); } + @Test + void can_explain_trendline() { + PhysicalPlan plan = + new TrendlineOperator( + tableScan, + Arrays.asList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), DOUBLE), + Pair.of(AstDSL.computation(3, AstDSL.field("time"), "time_alias", SMA), DOUBLE))); + assertEquals( + new ExplainResponse( + new ExplainResponseNode( + "TrendlineOperator", + ImmutableMap.of( + "computations", + List.of( + ImmutableMap.of( + "computationType", + "sma", + "numberOfDataPoints", + "2", + "dataField", + "distance", + "alias", + "distance_alias"), + ImmutableMap.of( + "computationType", + "sma", + "numberOfDataPoints", + "3", + "dataField", + "time", + "alias", + "time_alias"))), + singletonList(tableScan.explainNode()))), + explain.apply(plan)); + } + private static class FakeTableScan extends TableScanOperator { @Override public boolean hasNext() { diff --git a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java index 8e71fc2bec..8ee0dd7e70 100644 --- a/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/DefaultImplementorTest.java @@ -7,11 +7,13 @@ import static java.util.Collections.emptyList; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; @@ -44,8 +46,10 @@ import org.junit.jupiter.api.extension.ExtendWith; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.executor.pagination.PlanSerializer; @@ -63,11 +67,13 @@ import org.opensearch.sql.planner.logical.LogicalPlanDSL; import org.opensearch.sql.planner.logical.LogicalProject; import org.opensearch.sql.planner.logical.LogicalRelation; +import org.opensearch.sql.planner.logical.LogicalTrendline; import org.opensearch.sql.planner.logical.LogicalValues; import org.opensearch.sql.planner.physical.CursorCloseOperator; import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.planner.physical.ProjectOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; @@ -304,4 +310,19 @@ public void visitLimit_support_return_takeOrdered() { 5); assertEquals(physicalPlanTree, logicalLimit.accept(implementor, null)); } + + @Test + public void visitTrendline_should_build_TrendlineOperator() { + var logicalChild = mock(LogicalPlan.class); + var physicalChild = mock(PhysicalPlan.class); + when(logicalChild.accept(implementor, null)).thenReturn(physicalChild); + final Trendline.TrendlineComputation computation = + AstDSL.computation(1, AstDSL.field("field"), "alias", SMA); + var logicalPlan = + new LogicalTrendline( + logicalChild, Collections.singletonList(Pair.of(computation, ExprCoreType.DOUBLE))); + var implemented = logicalPlan.accept(implementor, null); + assertInstanceOf(TrendlineOperator.class, implemented); + assertSame(physicalChild, implemented.getChild().get(0)); + } } diff --git a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java index f212749f48..43ce23ed56 100644 --- a/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/logical/LogicalPlanNodeVisitorTest.java @@ -8,6 +8,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.Mockito.mock; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.named; @@ -25,9 +26,11 @@ import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; import org.junit.jupiter.params.provider.MethodSource; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.Expression; import org.opensearch.sql.expression.LiteralExpression; @@ -141,6 +144,14 @@ public TableWriteOperator build(PhysicalPlan child) { LogicalCloseCursor closeCursor = new LogicalCloseCursor(cursor); + LogicalTrendline trendline = + new LogicalTrendline( + relation, + Collections.singletonList( + Pair.of( + AstDSL.computation(1, AstDSL.field("testField"), "dummy", SMA), + ExprCoreType.DOUBLE))); + return Stream.of( relation, tableScanBuilder, @@ -163,7 +174,8 @@ public TableWriteOperator build(PhysicalPlan child) { paginate, nested, cursor, - closeCursor) + closeCursor, + trendline) .map(Arguments::of); } diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java index 17fb128ace..26f288e6b6 100644 --- a/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/physical/PhysicalPlanNodeVisitorTest.java @@ -9,6 +9,7 @@ import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNull; import static org.mockito.Mockito.mock; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.expression.DSL.named; @@ -29,6 +30,7 @@ import com.google.common.base.Strings; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -43,6 +45,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.expression.DSL; @@ -65,7 +68,16 @@ public void print_physical_plan() { agg( rareTopN( filter( - limit(new TestScan(), 1, 1), + limit( + new TrendlineOperator( + new TestScan(), + Collections.singletonList( + Pair.of( + AstDSL.computation( + 1, AstDSL.field("field"), "alias", SMA), + DOUBLE))), + 1, + 1), DSL.equal(DSL.ref("response", INTEGER), DSL.literal(10))), CommandType.TOP, ImmutableList.of(), @@ -85,7 +97,8 @@ public void print_physical_plan() { + "\t\t\tAggregation->\n" + "\t\t\t\tRareTopN->\n" + "\t\t\t\t\tFilter->\n" - + "\t\t\t\t\t\tLimit->", + + "\t\t\t\t\t\tLimit->\n" + + "\t\t\t\t\t\t\tTrendline->", printer.print(plan)); } @@ -134,6 +147,12 @@ public static Stream getPhysicalPlanForTest() { PhysicalPlan cursorClose = new CursorCloseOperator(plan); + PhysicalPlan trendline = + new TrendlineOperator( + plan, + Collections.singletonList( + Pair.of(AstDSL.computation(1, AstDSL.field("field"), "alias", SMA), DOUBLE))); + return Stream.of( Arguments.of(filter, "filter"), Arguments.of(aggregation, "aggregation"), @@ -149,7 +168,8 @@ public static Stream getPhysicalPlanForTest() { Arguments.of(rareTopN, "rareTopN"), Arguments.of(limit, "limit"), Arguments.of(nested, "nested"), - Arguments.of(cursorClose, "cursorClose")); + Arguments.of(cursorClose, "cursorClose"), + Arguments.of(trendline, "trendline")); } @ParameterizedTest(name = "{1}") @@ -223,6 +243,11 @@ public String visitLimit(LimitOperator node, Integer tabs) { return name(node, "Limit->", tabs); } + @Override + public String visitTrendline(TrendlineOperator node, Integer tabs) { + return name(node, "Trendline->", tabs); + } + private String name(PhysicalPlan node, String current, int tabs) { String child = node.getChild().get(0).accept(this, tabs + 1); StringBuilder sb = new StringBuilder(); diff --git a/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java new file mode 100644 index 0000000000..ef2c2907ce --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/planner/physical/TrendlineOperatorTest.java @@ -0,0 +1,398 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.planner.physical; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; + +import com.google.common.collect.ImmutableMap; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalTime; +import java.util.Arrays; +import java.util.Collections; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.DisplayNameGeneration; +import org.junit.jupiter.api.DisplayNameGenerator; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.ast.dsl.AstDSL; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.data.type.ExprCoreType; + +@DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) +@ExtendWith(MockitoExtension.class) +public class TrendlineOperatorTest { + @Mock private PhysicalPlan inputPlan; + + @Test + public void calculates_simple_moving_average_one_field_one_sample() { + when(inputPlan.hasNext()).thenReturn(true, false); + when(inputPlan.next()) + .thenReturn(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), + plan.next()); + } + + @Test + public void calculates_simple_moving_average_one_field_two_samples() { + when(inputPlan.hasNext()).thenReturn(true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_one_field_two_samples_three_rows() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 150.0)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 200, "time", 10, "distance_alias", 200.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_multiple_computations() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 20))); + + var plan = + new TrendlineOperator( + inputPlan, + Arrays.asList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.DOUBLE), + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 150.0, "time_alias", 15.0)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "distance", 200, "time", 20, "distance_alias", 200.0, "time_alias", 20.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void alias_overwrites_input_field() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "time", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 150.0)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 200.0)), plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_one_field_two_samples_three_rows_null_value() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 300, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 200, "time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 300, "time", 10, "distance_alias", 250.0)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void use_null_value() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + ExprValueUtils.tupleValue(ImmutableMap.of("distance", 100, "time", 10))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(1, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.DOUBLE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", 10)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("distance", ExprNullValue.of(), "time", 10)), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("distance", 100, "time", 10, "distance_alias", 100)), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void use_illegal_core_type() { + assertThrows( + IllegalArgumentException.class, + () -> { + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("distance"), "distance_alias", SMA), + ExprCoreType.ARRAY))); + }); + } + + @Test + public void calculates_simple_moving_average_date() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)))), + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("date"), "date_alias", SMA), + ExprCoreType.DATE))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of("date", ExprValueUtils.dateValue(LocalDate.EPOCH))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(6)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(3)))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "date", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(12)), + "date_alias", + ExprValueUtils.dateValue(LocalDate.EPOCH.plusDays(9)))), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_time() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN))), + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(6)))), + ExprValueUtils.tupleValue( + ImmutableMap.of("time", ExprValueUtils.timeValue(LocalTime.MIN.plusHours(12))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("time"), "time_alias", SMA), + ExprCoreType.TIME))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals(ExprValueUtils.tupleValue(ImmutableMap.of("time", LocalTime.MIN)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(6), "time_alias", LocalTime.MIN.plusHours(3))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "time", LocalTime.MIN.plusHours(12), "time_alias", LocalTime.MIN.plusHours(9))), + plan.next()); + assertFalse(plan.hasNext()); + } + + @Test + public void calculates_simple_moving_average_timestamp() { + when(inputPlan.hasNext()).thenReturn(true, true, true, false); + when(inputPlan.next()) + .thenReturn( + ExprValueUtils.tupleValue( + ImmutableMap.of("timestamp", ExprValueUtils.timestampValue(Instant.EPOCH))), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1000)))), + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", ExprValueUtils.timestampValue(Instant.EPOCH.plusMillis(1500))))); + + var plan = + new TrendlineOperator( + inputPlan, + Collections.singletonList( + Pair.of( + AstDSL.computation(2, AstDSL.field("timestamp"), "timestamp_alias", SMA), + ExprCoreType.TIMESTAMP))); + + plan.open(); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue(ImmutableMap.of("timestamp", Instant.EPOCH)), plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1000), + "timestamp_alias", + Instant.EPOCH.plusMillis(500))), + plan.next()); + assertTrue(plan.hasNext()); + assertEquals( + ExprValueUtils.tupleValue( + ImmutableMap.of( + "timestamp", + Instant.EPOCH.plusMillis(1500), + "timestamp_alias", + Instant.EPOCH.plusMillis(1250))), + plan.next()); + assertFalse(plan.hasNext()); + } +} diff --git a/docs/category.json b/docs/category.json index aacfc43478..32f56cfb46 100644 --- a/docs/category.json +++ b/docs/category.json @@ -25,6 +25,7 @@ "user/ppl/cmd/sort.rst", "user/ppl/cmd/stats.rst", "user/ppl/cmd/syntax.rst", + "user/ppl/cmd/trendline.rst", "user/ppl/cmd/top.rst", "user/ppl/cmd/where.rst", "user/ppl/general/identifiers.rst", diff --git a/docs/user/ppl/cmd/trendline.rst b/docs/user/ppl/cmd/trendline.rst new file mode 100644 index 0000000000..166a3c056f --- /dev/null +++ b/docs/user/ppl/cmd/trendline.rst @@ -0,0 +1,90 @@ +============= +trendline +============= + +.. rubric:: Table of contents + +.. contents:: + :local: + :depth: 2 + + +Description +============ +| Using ``trendline`` command to calculate moving averages of fields. + +Syntax +============ +`TRENDLINE [sort <[+|-] sort-field>] SMA(number-of-datapoints, field) [AS alias] [SMA(number-of-datapoints, field) [AS alias]]...` + +* [+|-]: optional. The plus [+] stands for ascending order and NULL/MISSING first and a minus [-] stands for descending order and NULL/MISSING last. **Default:** ascending order and NULL/MISSING first. +* sort-field: mandatory when sorting is used. The field used to sort. +* number-of-datapoints: mandatory. The number of datapoints to calculate the moving average (must be greater than zero). +* field: mandatory. The name of the field the moving average should be calculated for. +* alias: optional. The name of the resulting column containing the moving average (defaults to the field name with "_trendline"). + +And the moment only the Simple Moving Average (SMA) type is supported. + +It is calculated like + + f[i]: The value of field 'f' in the i-th data-point + n: The number of data-points in the moving window (period) + t: The current time index + + SMA(t) = (1/n) * Σ(f[i]), where i = t-n+1 to t + +Example 1: Calculate the moving average on one field. +===================================================== + +The example shows how to calculate the moving average on one field. + +PPL query:: + + os> source=accounts | trendline sma(2, account_number) as an | fields an; + fetched rows / total rows = 4/4 + +------+ + | an | + |------| + | null | + | 3.5 | + | 9.5 | + | 15.5 | + +------+ + + +Example 2: Calculate the moving average on multiple fields. +=========================================================== + +The example shows how to calculate the moving average on multiple fields. + +PPL query:: + + os> source=accounts | trendline sma(2, account_number) as an sma(2, age) as age_trend | fields an, age_trend ; + fetched rows / total rows = 4/4 + +------+-----------+ + | an | age_trend | + |------+-----------| + | null | null | + | 3.5 | 34.0 | + | 9.5 | 32.0 | + | 15.5 | 30.5 | + +------+-----------+ + +Example 4: Calculate the moving average on one field without specifying an alias. +================================================================================= + +The example shows how to calculate the moving average on one field. + +PPL query:: + + os> source=accounts | trendline sma(2, account_number) | fields account_number_trendline; + fetched rows / total rows = 4/4 + +--------------------------+ + | account_number_trendline | + |--------------------------| + | null | + | 3.5 | + | 9.5 | + | 15.5 | + +--------------------------+ + diff --git a/docs/user/ppl/index.rst b/docs/user/ppl/index.rst index 9525874c59..ef8cff334e 100644 --- a/docs/user/ppl/index.rst +++ b/docs/user/ppl/index.rst @@ -74,6 +74,8 @@ The query start with search command and then flowing a set of command delimited - `stats command `_ + - `trendline command `_ + - `where command `_ - `head command `_ diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java index b9c7f89ba0..531a24bad6 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/ExplainIT.java @@ -100,6 +100,32 @@ public void testFillNullPushDownExplain() throws Exception { + " | fillnull with -1 in age,balance | fields age, balance")); } + @Test + public void testTrendlinePushDownExplain() throws Exception { + String expected = loadFromFile("expectedOutput/ppl/explain_trendline_push.json"); + + assertJsonEquals( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account" + + "| head 5 " + + "| trendline sma(2, age) as ageTrend " + + "| fields ageTrend")); + } + + @Test + public void testTrendlineWithSortPushDownExplain() throws Exception { + String expected = loadFromFile("expectedOutput/ppl/explain_trendline_sort_push.json"); + + assertJsonEquals( + expected, + explainQueryToString( + "source=opensearch-sql_test_index_account" + + "| head 5 " + + "| trendline sort age sma(2, age) as ageTrend " + + "| fields ageTrend")); + } + String loadFromFile(String filename) throws Exception { URI uri = Resources.getResource(filename).toURI(); return new String(Files.readAllBytes(Paths.get(uri))); diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java new file mode 100644 index 0000000000..38baa0f01f --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/TrendlineCommandIT.java @@ -0,0 +1,78 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl; + +import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; +import static org.opensearch.sql.util.MatcherUtils.rows; +import static org.opensearch.sql.util.MatcherUtils.verifyDataRows; + +import java.io.IOException; +import org.json.JSONObject; +import org.junit.jupiter.api.Test; + +public class TrendlineCommandIT extends PPLIntegTestCase { + + @Override + public void init() throws IOException { + loadIndex(Index.BANK); + } + + @Test + public void testTrendline() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " balance_trend | fields balance_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } + + @Test + public void testTrendlineMultipleFields() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " balance_trend sma(2, account_number) as account_number_trend | fields" + + " balance_trend, account_number_trend", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(null, null), rows(44313.0, 28.5), rows(39882.5, 13.0)); + } + + @Test + public void testTrendlineOverwritesExistingField() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) as" + + " age | fields age", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } + + @Test + public void testTrendlineNoAlias() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | sort balance | trendline sma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } + + @Test + public void testTrendlineWithSort() throws IOException { + final JSONObject result = + executeQuery( + String.format( + "source=%s | where balance > 39000 | trendline sort balance sma(2, balance) |" + + " fields balance_trendline", + TEST_INDEX_BANK)); + verifyDataRows(result, rows(new Object[] {null}), rows(44313.0), rows(39882.5)); + } +} diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json new file mode 100644 index 0000000000..754535dc32 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_push.json @@ -0,0 +1,32 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[ageTrend]" + }, + "children": [ + { + "name": "TrendlineOperator", + "description": { + "computations": [ + { + "computationType" : "sma", + "numberOfDataPoints" : "2", + "dataField" : "age", + "alias" : "ageTrend" + } + ] + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName=opensearch-sql_test_index_account, sourceBuilder={\"from\":0,\"size\":5,\"timeout\":\"1m\"}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" + }, + "children": [] + } + ] + } + ] + } +} diff --git a/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json new file mode 100644 index 0000000000..6629434108 --- /dev/null +++ b/integ-test/src/test/resources/expectedOutput/ppl/explain_trendline_sort_push.json @@ -0,0 +1,32 @@ +{ + "root": { + "name": "ProjectOperator", + "description": { + "fields": "[ageTrend]" + }, + "children": [ + { + "name": "TrendlineOperator", + "description": { + "computations": [ + { + "computationType" : "sma", + "numberOfDataPoints" : "2", + "dataField" : "age", + "alias" : "ageTrend" + } + ] + }, + "children": [ + { + "name": "OpenSearchIndexScan", + "description": { + "request": "OpenSearchQueryRequest(indexName\u003dopensearch-sql_test_index_account, sourceBuilder\u003d{\"from\":0,\"size\":5,\"timeout\":\"1m\",\"sort\":[{\"age\":{\"order\":\"asc\",\"missing\":\"_first\"}}]}, needClean\u003dtrue, searchDone\u003dfalse, pitId\u003dnull, cursorKeepAlive\u003dnull, searchAfter\u003dnull, searchResponse\u003dnull)" + }, + "children": [] + } + ] + } + ] + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java index 28827b0a54..358bc10ab4 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtector.java @@ -24,6 +24,7 @@ import org.opensearch.sql.planner.physical.RenameOperator; import org.opensearch.sql.planner.physical.SortOperator; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; import org.opensearch.sql.planner.physical.ValuesOperator; import org.opensearch.sql.planner.physical.WindowOperator; import org.opensearch.sql.storage.TableScanOperator; @@ -187,6 +188,12 @@ public PhysicalPlan visitML(PhysicalPlan node, Object context) { mlOperator.getNodeClient())); } + @Override + public PhysicalPlan visitTrendline(TrendlineOperator node, Object context) { + return doProtect( + new TrendlineOperator(visitInput(node.getInput(), context), node.getComputations())); + } + PhysicalPlan visitInput(PhysicalPlan node, Object context) { if (null == node) { return node; diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java index da06c1eb66..724178bd34 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/executor/protector/OpenSearchExecutionProtectorTest.java @@ -10,6 +10,7 @@ import static org.junit.jupiter.api.Assertions.assertSame; import static org.mockito.Mockito.*; import static org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; @@ -23,6 +24,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -37,10 +39,12 @@ import org.mockito.Mock; import org.mockito.junit.jupiter.MockitoExtension; import org.opensearch.client.node.NodeClient; +import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.DataType; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.tree.RareTopN.CommandType; import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.common.setting.Settings; import org.opensearch.sql.data.model.ExprBooleanValue; import org.opensearch.sql.expression.DSL; @@ -67,6 +71,7 @@ import org.opensearch.sql.planner.physical.PhysicalPlan; import org.opensearch.sql.planner.physical.PhysicalPlanDSL; import org.opensearch.sql.planner.physical.TakeOrderedOperator; +import org.opensearch.sql.planner.physical.TrendlineOperator; @ExtendWith(MockitoExtension.class) @DisplayNameGeneration(DisplayNameGenerator.ReplaceUnderscores.class) @@ -318,6 +323,22 @@ public void test_visitTakeOrdered() { resourceMonitor(takeOrdered), executionProtector.visitTakeOrdered(takeOrdered, null)); } + @Test + public void test_visitTrendline() { + final TrendlineOperator trendlineOperator = + new TrendlineOperator( + PhysicalPlanDSL.values(emptyList()), + Collections.singletonList( + Pair.of( + new Trendline.TrendlineComputation( + 1, AstDSL.field("dummy"), "dummy_alias", SMA), + DOUBLE))); + + assertEquals( + resourceMonitor(trendlineOperator), + executionProtector.visitTrendline(trendlineOperator, null)); + } + PhysicalPlan resourceMonitor(PhysicalPlan input) { return new ResourceMonitorPlan(input, resourceMonitor); } diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index 3ba8da74f4..4a883fa656 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -36,6 +36,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +TRENDLINE: 'TRENDLINE'; // COMMAND ASSIST KEYWORDS AS: 'AS'; @@ -57,6 +58,9 @@ STR: 'STR'; IP: 'IP'; NUM: 'NUM'; +// TRENDLINE KEYWORDS +SMA: 'SMA'; + // ARGUMENT KEYWORDS KEEPEMPTY: 'KEEPEMPTY'; CONSECUTIVE: 'CONSECUTIVE'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 89a32abe23..c9d0f2e110 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -50,6 +50,7 @@ commands | adCommand | mlCommand | fillnullCommand + | trendlineCommand ; searchCommand @@ -145,6 +146,18 @@ nullReplacementExpression : nullableField = fieldExpression EQUAL nullReplacement = valueExpression ; +trendlineCommand + : TRENDLINE (SORT sortField)? trendlineClause (trendlineClause)* + ; + +trendlineClause + : trendlineType LT_PRTHS numberOfDataPoints = integerLiteral COMMA field = fieldExpression RT_PRTHS (AS alias = qualifiedName)? + ; + +trendlineType + : SMA + ; + kmeansCommand : KMEANS (kmeansParameter)* ; @@ -876,6 +889,7 @@ keywordsCanBeId | KMEANS | AD | ML + | TRENDLINE // commands assist keywords | SOURCE | INDEX diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 2fccb8e635..c3c31ee2e1 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -64,6 +64,7 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; @@ -421,6 +422,21 @@ public UnresolvedPlan visitFillNullWithFieldVariousValues( FillNull.ContainNullableFieldFill.ofVariousValue(replacementsBuilder.build())); } + /** trendline command. */ + @Override + public UnresolvedPlan visitTrendlineCommand(OpenSearchPPLParser.TrendlineCommandContext ctx) { + List trendlineComputations = + ctx.trendlineClause().stream() + .map(expressionBuilder::visit) + .map(Trendline.TrendlineComputation.class::cast) + .collect(Collectors.toList()); + return Optional.ofNullable(ctx.sortField()) + .map(this::internalVisitExpression) + .map(Field.class::cast) + .map(sort -> new Trendline(Optional.of(sort), trendlineComputations)) + .orElse(new Trendline(Optional.empty(), trendlineComputations)); + } + /** Get original text in query. */ private String getTextInQuery(ParserRuleContext ctx) { Token start = ctx.getStart(); diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 98c41027ff..8bc98c8eee 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -45,6 +45,7 @@ import com.google.common.collect.ImmutableMap; import java.util.Arrays; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -52,6 +53,8 @@ import org.antlr.v4.runtime.RuleContext; import org.opensearch.sql.ast.dsl.AstDSL; import org.opensearch.sql.ast.expression.*; +import org.opensearch.sql.ast.tree.Trendline; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; @@ -75,6 +78,28 @@ public UnresolvedExpression visitEvalClause(EvalClauseContext ctx) { return new Let((Field) visit(ctx.fieldExpression()), visit(ctx.expression())); } + /** Trendline clause. */ + @Override + public Trendline.TrendlineComputation visitTrendlineClause( + OpenSearchPPLParser.TrendlineClauseContext ctx) { + final int numberOfDataPoints = Integer.parseInt(ctx.numberOfDataPoints.getText()); + if (numberOfDataPoints < 1) { + throw new SyntaxCheckException( + "Number of trendline data-points must be greater than or equal to 1"); + } + + final Field dataField = (Field) this.visitFieldExpression(ctx.field); + final String alias = + ctx.alias != null + ? ctx.alias.getText() + : dataField.getChild().get(0).toString() + "_trendline"; + + final Trendline.TrendlineType computationType = + Trendline.TrendlineType.valueOf(ctx.trendlineType().getText().toUpperCase(Locale.ROOT)); + return new Trendline.TrendlineComputation( + numberOfDataPoints, dataField, alias, computationType); + } + /** Logical expression excluding boolean, comparison. */ @Override public UnresolvedExpression visitLogicalNot(LogicalNotContext ctx) { diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java index a1ca0fd69a..96e21eafcd 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizer.java @@ -9,6 +9,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.List; +import java.util.Locale; import java.util.stream.Collectors; import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; @@ -43,6 +44,7 @@ import org.opensearch.sql.ast.tree.Rename; import org.opensearch.sql.ast.tree.Sort; import org.opensearch.sql.ast.tree.TableFunction; +import org.opensearch.sql.ast.tree.Trendline; import org.opensearch.sql.ast.tree.UnresolvedPlan; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.planner.logical.LogicalAggregation; @@ -221,14 +223,26 @@ public String visitHead(Head node, String context) { return StringUtils.format("%s | head %d", child, size); } + @Override + public String visitTrendline(Trendline node, String context) { + String child = node.getChild().get(0).accept(this, context); + String computations = visitExpressionList(node.getComputations(), " "); + return StringUtils.format("%s | trendline %s", child, computations); + } + private String visitFieldList(List fieldList) { return fieldList.stream().map(this::visitExpression).collect(Collectors.joining(",")); } - private String visitExpressionList(List expressionList) { + private String visitExpressionList(List expressionList) { + return visitExpressionList(expressionList, ","); + } + + private String visitExpressionList( + List expressionList, String delimiter) { return expressionList.isEmpty() ? "" - : expressionList.stream().map(this::visitExpression).collect(Collectors.joining(",")); + : expressionList.stream().map(this::visitExpression).collect(Collectors.joining(delimiter)); } private String visitExpression(UnresolvedExpression expression) { @@ -344,5 +358,14 @@ public String visitAlias(Alias node, String context) { String expr = node.getDelegated().accept(this, context); return StringUtils.format("%s", expr); } + + @Override + public String visitTrendlineComputation(Trendline.TrendlineComputation node, String context) { + final String dataField = node.getDataField().accept(this, context); + final String aliasClause = " as " + node.getAlias(); + final String computationType = node.getComputationType().name().toLowerCase(Locale.ROOT); + return StringUtils.format( + "%s(%d, %s)%s", computationType, node.getNumberOfDataPoints(), dataField, aliasClause); + } } } diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java index ac2bce9dbc..c6f4ed2044 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstBuilderTest.java @@ -7,12 +7,14 @@ import static java.util.Collections.emptyList; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; import static org.opensearch.sql.ast.dsl.AstDSL.agg; import static org.opensearch.sql.ast.dsl.AstDSL.aggregate; import static org.opensearch.sql.ast.dsl.AstDSL.alias; import static org.opensearch.sql.ast.dsl.AstDSL.argument; import static org.opensearch.sql.ast.dsl.AstDSL.booleanLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.compare; +import static org.opensearch.sql.ast.dsl.AstDSL.computation; import static org.opensearch.sql.ast.dsl.AstDSL.dedupe; import static org.opensearch.sql.ast.dsl.AstDSL.defaultDedupArgs; import static org.opensearch.sql.ast.dsl.AstDSL.defaultFieldsArgs; @@ -38,13 +40,16 @@ import static org.opensearch.sql.ast.dsl.AstDSL.span; import static org.opensearch.sql.ast.dsl.AstDSL.stringLiteral; import static org.opensearch.sql.ast.dsl.AstDSL.tableFunction; +import static org.opensearch.sql.ast.dsl.AstDSL.trendline; import static org.opensearch.sql.ast.dsl.AstDSL.unresolvedArg; +import static org.opensearch.sql.ast.tree.Trendline.TrendlineType.SMA; import static org.opensearch.sql.utils.SystemIndexUtils.DATASOURCES_TABLE_NAME; import static org.opensearch.sql.utils.SystemIndexUtils.mappingTable; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; +import java.util.Optional; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -60,6 +65,7 @@ import org.opensearch.sql.ast.tree.Kmeans; import org.opensearch.sql.ast.tree.ML; import org.opensearch.sql.ast.tree.RareTopN.CommandType; +import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ppl.antlr.PPLSyntaxParser; public class AstBuilderTest { @@ -692,6 +698,74 @@ public void testFillNullCommandVariousValues() { .build()))); } + public void testTrendline() { + assertEqual( + "source=t | trendline sma(5, test_field) as test_field_alias sma(1, test_field_2) as" + + " test_field_alias_2", + trendline( + relation("t"), + Optional.empty(), + computation(5, field("test_field"), "test_field_alias", SMA), + computation(1, field("test_field_2"), "test_field_alias_2", SMA))); + } + + @Test + public void testTrendlineSort() { + assertEqual( + "source=t | trendline sort test_field sma(5, test_field)", + trendline( + relation("t"), + Optional.of( + field( + "test_field", + argument("asc", booleanLiteral(true)), + argument("type", nullLiteral()))), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + + @Test + public void testTrendlineSortDesc() { + assertEqual( + "source=t | trendline sort - test_field sma(5, test_field)", + trendline( + relation("t"), + Optional.of( + field( + "test_field", + argument("asc", booleanLiteral(false)), + argument("type", nullLiteral()))), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + + @Test + public void testTrendlineSortAsc() { + assertEqual( + "source=t | trendline sort + test_field sma(5, test_field)", + trendline( + relation("t"), + Optional.of( + field( + "test_field", + argument("asc", booleanLiteral(true)), + argument("type", nullLiteral()))), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + + @Test + public void testTrendlineNoAlias() { + assertEqual( + "source=t | trendline sma(5, test_field)", + trendline( + relation("t"), + Optional.empty(), + computation(5, field("test_field"), "test_field_trendline", SMA))); + } + + @Test + public void testTrendlineTooFewSamples() { + assertThrows(SyntaxCheckException.class, () -> plan("source=t | trendline sma(0, test_field)")); + } + @Test public void testDescribeCommand() { assertEqual("describe t", relation(mappingTable("t"))); diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java index b5b4c97f13..06f8fbb061 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/utils/PPLQueryDataAnonymizerTest.java @@ -89,6 +89,13 @@ public void testDedupCommand() { anonymize("source=t | dedup f1, f2")); } + @Test + public void testTrendlineCommand() { + assertEquals( + "source=t | trendline sma(2, date) as date_alias sma(3, time) as time_alias", + anonymize("source=t | trendline sma(2, date) as date_alias sma(3, time) as time_alias")); + } + @Test public void testHeadCommandWithNumber() { assertEquals("source=t | head 3", anonymize("source=t | head 3"));