From 63b00ba097a76a1dd706e5159b601855e190b352 Mon Sep 17 00:00:00 2001 From: Mitchell Gale Date: Wed, 9 Aug 2023 15:03:27 -0700 Subject: [PATCH] [Spotless] Applying Google Code Format for core/src/main files #4 (#1933) * GJF part 4 Signed-off-by: Mitchell Gale * add build.gradle comment to mention why we are ignoring checkstyle failures for core. Signed-off-by: Mitchell Gale * Fix include spotless build gradle. Signed-off-by: Mitchell Gale * revert astDSL.JAVA Signed-off-by: Mitchell Gale * revert ast changes as was covered in spotless #1 PR for GJF. Signed-off-by: Mitchell Gale * Reverting commits in ast folder attempt #2 Signed-off-by: Mitchell Gale * revert change to RaretopN.java Signed-off-by: Mitchell Gale * addressed PR comments. Signed-off-by: Mitchell Gale * Replacing removed include in spotless. Signed-off-by: Mitchell Gale --------- Signed-off-by: Mitchell Gale Signed-off-by: Mitchell Gale --- build.gradle | 6 +- core/build.gradle | 4 +- .../opensearch/sql/DataSourceSchemaName.java | 1 - .../sql/analysis/AnalysisContextTest.java | 1 - .../opensearch/sql/analysis/AnalyzerTest.java | 1363 ++++++++--------- .../sql/analysis/AnalyzerTestBase.java | 101 +- .../sql/analysis/ExpressionAnalyzerTest.java | 665 ++++---- .../ExpressionReferenceOptimizerTest.java | 85 +- .../analysis/NamedExpressionAnalyzerTest.java | 13 +- .../sql/analysis/QualifierAnalyzerTest.java | 47 +- .../sql/analysis/SelectAnalyzeTest.java | 43 +- .../SelectExpressionAnalyzerTest.java | 39 +- .../sql/analysis/TypeEnvironmentTest.java | 26 +- .../WindowExpressionAnalyzerTest.java | 45 +- ...ourceSchemaIdentifierNameResolverTest.java | 11 +- .../sql/analysis/symbol/SymbolTableTest.java | 25 +- .../org/opensearch/sql/config/TestConfig.java | 66 +- 17 files changed, 1221 insertions(+), 1320 deletions(-) diff --git a/build.gradle b/build.gradle index 3e75433d83..71f94636b5 100644 --- a/build.gradle +++ b/build.gradle @@ -84,7 +84,11 @@ repositories { spotless { java { target fileTree('.') { - include 'core/src/main/java/org/opensearch/sql/planner/**/*.java', + include 'core/src/main/java/org/opensearch/sql/DataSourceSchemaName.java', + 'core/src/test/java/org/opensearch/sql/data/**/*.java', + 'core/src/test/java/org/opensearch/sql/config/**/*.java', + 'core/src/test/java/org/opensearch/sql/analysis/**/*.java', + 'core/src/main/java/org/opensearch/sql/planner/**/*.java', 'core/src/main/java/org/opensearch/sql/storage/**/*.java', 'core/src/main/java/org/opensearch/sql/utils/**/*.java', 'core/src/main/java/org/opensearch/sql/monitor/**/*.java', diff --git a/core/build.gradle b/core/build.gradle index 89fac623f2..cf7f0b7a1c 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -34,8 +34,10 @@ repositories { mavenCentral() } -checkstyleMain.ignoreFailures = true +// Being ignored as a temporary measure before being removed in favour of +// spotless https://github.com/opensearch-project/sql/issues/1101 checkstyleTest.ignoreFailures = true +checkstyleMain.ignoreFailures = true pitest { targetClasses = ['org.opensearch.sql.*'] diff --git a/core/src/main/java/org/opensearch/sql/DataSourceSchemaName.java b/core/src/main/java/org/opensearch/sql/DataSourceSchemaName.java index 47988097c3..9c9dfa0772 100644 --- a/core/src/main/java/org/opensearch/sql/DataSourceSchemaName.java +++ b/core/src/main/java/org/opensearch/sql/DataSourceSchemaName.java @@ -17,5 +17,4 @@ public class DataSourceSchemaName { private final String dataSourceName; private final String schemaName; - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalysisContextTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalysisContextTest.java index 0d643aa53f..b052fe47ce 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalysisContextTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalysisContextTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java index 100cfd67af..2f4d6e8ada 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static java.util.Collections.emptyList; @@ -134,17 +133,13 @@ public void filter_relation_with_reserved_qualifiedName() { @Test public void filter_relation_with_invalid_qualifiedName_SemanticCheckException() { - UnresolvedPlan invalidFieldPlan = AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.equalTo( - AstDSL.qualifiedName("_invalid"), - AstDSL.stringLiteral("value")) - ); + UnresolvedPlan invalidFieldPlan = + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.equalTo(AstDSL.qualifiedName("_invalid"), AstDSL.stringLiteral("value"))); SemanticCheckException exception = - assertThrows( - SemanticCheckException.class, - () -> analyze(invalidFieldPlan)); + assertThrows(SemanticCheckException.class, () -> analyze(invalidFieldPlan)); assertEquals( "can't resolve Symbol(namespace=FIELD_NAME, name=_invalid) in type env", exception.getMessage()); @@ -152,15 +147,13 @@ public void filter_relation_with_invalid_qualifiedName_SemanticCheckException() @Test public void filter_relation_with_invalid_qualifiedName_ExpressionEvaluationException() { - UnresolvedPlan typeMismatchPlan = AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.equalTo(AstDSL.qualifiedName("_test"), AstDSL.intLiteral(1)) - ); + UnresolvedPlan typeMismatchPlan = + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.equalTo(AstDSL.qualifiedName("_test"), AstDSL.intLiteral(1))); ExpressionEvaluationException exception = - assertThrows( - ExpressionEvaluationException.class, - () -> analyze(typeMismatchPlan)); + assertThrows(ExpressionEvaluationException.class, () -> analyze(typeMismatchPlan)); assertEquals( "= function expected {[BYTE,BYTE],[SHORT,SHORT],[INTEGER,INTEGER],[LONG,LONG]," + "[FLOAT,FLOAT],[DOUBLE,DOUBLE],[STRING,STRING],[BOOLEAN,BOOLEAN],[DATE,DATE]," @@ -265,8 +258,8 @@ public void filter_relation_with_non_existing_datasource_with_three_parts() { LogicalPlanDSL.relation("test.nonexisting_schema.http_total_requests", table), DSL.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1)))), AstDSL.filter( - AstDSL.relation(AstDSL.qualifiedName("test", - "nonexisting_schema", "http_total_requests")), + AstDSL.relation( + AstDSL.qualifiedName("test", "nonexisting_schema", "http_total_requests")), AstDSL.equalTo(AstDSL.field("integer_value"), AstDSL.intLiteral(1)))); } @@ -283,73 +276,68 @@ public void filter_relation_with_multiple_tables() { @Test public void analyze_filter_visit_score_function() { - UnresolvedPlan unresolvedPlan = AstDSL.filter( - AstDSL.relation("schema"), - new ScoreFunction( - AstDSL.function("match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("boost", stringLiteral("3")) - ), AstDSL.doubleLiteral(1.0)) - ); + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3"))), + AstDSL.doubleLiteral(1.0))); assertAnalyzeEqual( LogicalPlanDSL.filter( LogicalPlanDSL.relation("schema", table), DSL.match_phrase_prefix( DSL.namedArgument("field", "field_value1"), DSL.namedArgument("query", "search query"), - DSL.namedArgument("boost", "3.0") - ) - ), - unresolvedPlan - ); + DSL.namedArgument("boost", "3.0"))), + unresolvedPlan); LogicalPlan logicalPlan = analyze(unresolvedPlan); OpenSearchFunctions.OpenSearchFunction relevanceQuery = - (OpenSearchFunctions.OpenSearchFunction)((LogicalFilter) logicalPlan).getCondition(); + (OpenSearchFunctions.OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); assertEquals(true, relevanceQuery.isScoreTracked()); } @Test public void analyze_filter_visit_without_score_function() { - UnresolvedPlan unresolvedPlan = AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("boost", stringLiteral("3")) - ) - ); + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3")))); assertAnalyzeEqual( LogicalPlanDSL.filter( LogicalPlanDSL.relation("schema", table), DSL.match_phrase_prefix( DSL.namedArgument("field", "field_value1"), DSL.namedArgument("query", "search query"), - DSL.namedArgument("boost", "3") - ) - ), - unresolvedPlan - ); + DSL.namedArgument("boost", "3"))), + unresolvedPlan); LogicalPlan logicalPlan = analyze(unresolvedPlan); OpenSearchFunctions.OpenSearchFunction relevanceQuery = - (OpenSearchFunctions.OpenSearchFunction)((LogicalFilter) logicalPlan).getCondition(); + (OpenSearchFunctions.OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); assertEquals(false, relevanceQuery.isScoreTracked()); } @Test public void analyze_filter_visit_score_function_with_double_boost() { - UnresolvedPlan unresolvedPlan = AstDSL.filter( - AstDSL.relation("schema"), - new ScoreFunction( - AstDSL.function("match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("slop", stringLiteral("3")) - ), new Literal(3.0, DataType.DOUBLE) - ) - ); + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("slop", stringLiteral("3"))), + new Literal(3.0, DataType.DOUBLE))); assertAnalyzeEqual( LogicalPlanDSL.filter( @@ -358,44 +346,36 @@ public void analyze_filter_visit_score_function_with_double_boost() { DSL.namedArgument("field", "field_value1"), DSL.namedArgument("query", "search query"), DSL.namedArgument("slop", "3"), - DSL.namedArgument("boost", "3.0") - ) - ), - unresolvedPlan - ); + DSL.namedArgument("boost", "3.0"))), + unresolvedPlan); LogicalPlan logicalPlan = analyze(unresolvedPlan); OpenSearchFunctions.OpenSearchFunction relevanceQuery = - (OpenSearchFunctions.OpenSearchFunction)((LogicalFilter) logicalPlan).getCondition(); + (OpenSearchFunctions.OpenSearchFunction) ((LogicalFilter) logicalPlan).getCondition(); assertEquals(true, relevanceQuery.isScoreTracked()); } @Test public void analyze_filter_visit_score_function_with_unsupported_boost_SemanticCheckException() { - UnresolvedPlan unresolvedPlan = AstDSL.filter( - AstDSL.relation("schema"), - new ScoreFunction( - AstDSL.function("match_phrase_prefix", - AstDSL.unresolvedArg("field", stringLiteral("field_value1")), - AstDSL.unresolvedArg("query", stringLiteral("search query")), - AstDSL.unresolvedArg("boost", stringLiteral("3")) - ), AstDSL.stringLiteral("3.0") - ) - ); + UnresolvedPlan unresolvedPlan = + AstDSL.filter( + AstDSL.relation("schema"), + new ScoreFunction( + AstDSL.function( + "match_phrase_prefix", + AstDSL.unresolvedArg("field", stringLiteral("field_value1")), + AstDSL.unresolvedArg("query", stringLiteral("search query")), + AstDSL.unresolvedArg("boost", stringLiteral("3"))), + AstDSL.stringLiteral("3.0"))); SemanticCheckException exception = - assertThrows( - SemanticCheckException.class, - () -> analyze(unresolvedPlan)); - assertEquals( - "Expected boost type 'DOUBLE' but got 'STRING'", - exception.getMessage()); + assertThrows(SemanticCheckException.class, () -> analyze(unresolvedPlan)); + assertEquals("Expected boost type 'DOUBLE' but got 'STRING'", exception.getMessage()); } @Test public void head_relation() { assertAnalyzeEqual( - LogicalPlanDSL.limit(LogicalPlanDSL.relation("schema", table), - 10, 0), + LogicalPlanDSL.limit(LogicalPlanDSL.relation("schema", table), 10, 0), AstDSL.head(AstDSL.relation("schema"), 10, 0)); } @@ -418,7 +398,7 @@ public void analyze_filter_aggregation_relation() { DSL.named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER))), DSL.named("MIN(integer_value)", DSL.min(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), - DSL.greater(// Expect to be replaced with reference by expression optimizer + DSL.greater( // Expect to be replaced with reference by expression optimizer DSL.ref("MIN(integer_value)", INTEGER), DSL.literal(integerValue(10)))), AstDSL.filter( AstDSL.agg( @@ -429,8 +409,7 @@ public void analyze_filter_aggregation_relation() { emptyList(), ImmutableList.of(alias("string_value", qualifiedName("string_value"))), emptyList()), - compare(">", - aggregate("MIN", qualifiedName("integer_value")), intLiteral(10)))); + compare(">", aggregate("MIN", qualifiedName("integer_value")), intLiteral(10)))); } @Test @@ -449,19 +428,16 @@ public void stats_source() { assertAnalyzeEqual( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL.named("avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), AstDSL.agg( AstDSL.relation("schema"), AstDSL.exprList( AstDSL.alias( - "avg(integer_value)", - AstDSL.aggregate("avg", field("integer_value"))) - ), + "avg(integer_value)", AstDSL.aggregate("avg", field("integer_value")))), null, - ImmutableList.of( - AstDSL.alias("string_value", field("string_value"))), + ImmutableList.of(AstDSL.alias("string_value", field("string_value"))), AstDSL.defaultStatsArgs())); } @@ -473,16 +449,13 @@ public void rare_source() { CommandType.RARE, 10, ImmutableList.of(DSL.ref("string_value", STRING)), - DSL.ref("integer_value", INTEGER) - ), + DSL.ref("integer_value", INTEGER)), AstDSL.rareTopN( AstDSL.relation("schema"), CommandType.RARE, ImmutableList.of(argument("noOfResults", intLiteral(10))), ImmutableList.of(field("string_value")), - field("integer_value") - ) - ); + field("integer_value"))); } @Test @@ -493,16 +466,13 @@ public void top_source() { CommandType.TOP, 5, ImmutableList.of(DSL.ref("string_value", STRING)), - DSL.ref("integer_value", INTEGER) - ), + DSL.ref("integer_value", INTEGER)), AstDSL.rareTopN( AstDSL.relation("schema"), CommandType.TOP, ImmutableList.of(argument("noOfResults", intLiteral(5))), ImmutableList.of(field("string_value")), - field("integer_value") - ) - ); + field("integer_value"))); } @Test @@ -516,8 +486,9 @@ public void rename_to_invalid_expression() { AstDSL.agg( AstDSL.relation("schema"), AstDSL.exprList( - AstDSL.alias("avg(integer_value)", AstDSL.aggregate("avg", field( - "integer_value")))), + AstDSL.alias( + "avg(integer_value)", + AstDSL.aggregate("avg", field("integer_value")))), Collections.emptyList(), ImmutableList.of(), AstDSL.defaultStatsArgs()), @@ -535,8 +506,7 @@ public void project_source() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), - DSL.named("double_value", DSL.ref("double_value", DOUBLE)) - ), + DSL.named("double_value", DSL.ref("double_value", DOUBLE))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), @@ -550,34 +520,25 @@ public void project_nested_field_arg() { List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ) - ); + "path", new ReferenceExpression("message", STRING))); List projectList = List.of( new NamedExpression( - "nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING)), - null) - ); + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null)); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.info)", - function("nested", qualifiedName("message", "info")), null) - ) - ); + AstDSL.alias( + "nested(message.info)", + function("nested", qualifiedName("message", "info")), + null))); assertTrue(isNestedFunction(DSL.nested(DSL.ref("message.info", STRING)))); assertFalse(isNestedFunction(DSL.literal("fieldA"))); @@ -586,64 +547,51 @@ public void project_nested_field_arg() { @Test public void sort_with_nested_all_tuple_fields_throws_exception() { - assertThrows(UnsupportedOperationException.class, () -> analyze( - AstDSL.project( - AstDSL.sort( - AstDSL.relation("schema"), - field(nestedAllTupleFields("message")) - ), - AstDSL.alias("nested(message.*)", - nestedAllTupleFields("message")) - ) - )); + assertThrows( + UnsupportedOperationException.class, + () -> + analyze( + AstDSL.project( + AstDSL.sort(AstDSL.relation("schema"), field(nestedAllTupleFields("message"))), + AstDSL.alias("nested(message.*)", nestedAllTupleFields("message"))))); } @Test public void filter_with_nested_all_tuple_fields_throws_exception() { - assertThrows(UnsupportedOperationException.class, () -> analyze( - AstDSL.project( - AstDSL.filter( - AstDSL.relation("schema"), - AstDSL.function("=", nestedAllTupleFields("message"), AstDSL.intLiteral(1))), - AstDSL.alias("nested(message.*)", - nestedAllTupleFields("message")) - ) - )); + assertThrows( + UnsupportedOperationException.class, + () -> + analyze( + AstDSL.project( + AstDSL.filter( + AstDSL.relation("schema"), + AstDSL.function( + "=", nestedAllTupleFields("message"), AstDSL.intLiteral(1))), + AstDSL.alias("nested(message.*)", nestedAllTupleFields("message"))))); } - @Test public void project_nested_field_star_arg() { List> nestedArgs = List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ) - ); + "path", new ReferenceExpression("message", STRING))); List projectList = List.of( - new NamedExpression("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))) - ); + new NamedExpression( + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)))); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.*)", - nestedAllTupleFields("message")) - ) - ); + AstDSL.alias("nested(message.*)", nestedAllTupleFields("message")))); } @Test @@ -652,42 +600,29 @@ public void project_nested_field_star_arg_with_another_nested_function() { List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ), + "path", new ReferenceExpression("message", STRING)), Map.of( "field", new ReferenceExpression("comment.data", STRING), - "path", new ReferenceExpression("comment", STRING) - ) - ); + "path", new ReferenceExpression("comment", STRING))); List projectList = List.of( - new NamedExpression("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - new NamedExpression("nested(comment.data)", - DSL.nested(DSL.ref("comment.data", STRING))) - ); + new NamedExpression( + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + new NamedExpression( + "nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING)))); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("nested(comment.data)", - DSL.nested(DSL.ref("comment.data", STRING))) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING)))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.*)", - nestedAllTupleFields("message")), - AstDSL.alias("nested(comment.*)", - nestedAllTupleFields("comment")) - ) - ); + AstDSL.alias("nested(message.*)", nestedAllTupleFields("message")), + AstDSL.alias("nested(comment.*)", nestedAllTupleFields("comment")))); } @Test @@ -696,38 +631,25 @@ public void project_nested_field_star_arg_with_another_field() { List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ) - ); + "path", new ReferenceExpression("message", STRING))); List projectList = List.of( - new NamedExpression("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - new NamedExpression("comment.data", - DSL.ref("comment.data", STRING)) - ); + new NamedExpression( + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + new NamedExpression("comment.data", DSL.ref("comment.data", STRING))); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("comment.data", - DSL.ref("comment.data", STRING)) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("comment.data", DSL.ref("comment.data", STRING))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.*)", - nestedAllTupleFields("message")), - AstDSL.alias("comment.data", - field("comment.data")) - ) - ); + AstDSL.alias("nested(message.*)", nestedAllTupleFields("message")), + AstDSL.alias("comment.data", field("comment.data")))); } @Test @@ -736,41 +658,32 @@ public void project_nested_field_star_arg_with_highlight() { List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ) - ); + "path", new ReferenceExpression("message", STRING))); List projectList = List.of( - new NamedExpression("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("highlight(fieldA)", - new HighlightExpression(DSL.literal("fieldA"))) - ); + new NamedExpression( + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA")))); Map highlightArgs = new HashMap<>(); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("fieldA"), highlightArgs), + LogicalPlanDSL.highlight( + LogicalPlanDSL.relation("schema", table), DSL.literal("fieldA"), highlightArgs), nestedArgs, projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("highlight(fieldA)", - new HighlightExpression(DSL.literal("fieldA"))) - ), + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("highlight(fieldA)", new HighlightExpression(DSL.literal("fieldA")))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.*)", - nestedAllTupleFields("message")), - AstDSL.alias("highlight(fieldA)", - new HighlightFunction(AstDSL.stringLiteral("fieldA"), highlightArgs)) - ) - ); + AstDSL.alias("nested(message.*)", nestedAllTupleFields("message")), + AstDSL.alias( + "highlight(fieldA)", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), highlightArgs)))); } @Test @@ -779,40 +692,29 @@ public void project_nested_field_and_path_args() { List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ) - ); + "path", new ReferenceExpression("message", STRING))); List projectList = List.of( new NamedExpression( "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)), - null) - ); + null)); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING))) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named( + "nested(message.info)", + DSL.nested(DSL.ref("message.info", STRING), DSL.ref("message", STRING)))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.info)", - function( - "nested", - qualifiedName("message", "info"), - qualifiedName("message") - ), - null - ) - ) - ); + AstDSL.alias( + "nested(message.info)", + function("nested", qualifiedName("message", "info"), qualifiedName("message")), + null))); } @Test @@ -821,34 +723,25 @@ public void project_nested_deep_field_arg() { List.of( Map.of( "field", new ReferenceExpression("message.info.id", STRING), - "path", new ReferenceExpression("message.info", STRING) - ) - ); + "path", new ReferenceExpression("message.info", STRING))); List projectList = List.of( new NamedExpression( - "nested(message.info.id)", - DSL.nested(DSL.ref("message.info.id", STRING)), - null) - ); + "nested(message.info.id)", DSL.nested(DSL.ref("message.info.id", STRING)), null)); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info.id)", - DSL.nested(DSL.ref("message.info.id", STRING))) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named("nested(message.info.id)", DSL.nested(DSL.ref("message.info.id", STRING)))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.info.id)", - function("nested", qualifiedName("message", "info", "id")), null) - ) - ); + AstDSL.alias( + "nested(message.info.id)", + function("nested", qualifiedName("message", "info", "id")), + null))); } @Test @@ -857,114 +750,102 @@ public void project_multiple_nested() { List.of( Map.of( "field", new ReferenceExpression("message.info", STRING), - "path", new ReferenceExpression("message", STRING) - ), + "path", new ReferenceExpression("message", STRING)), Map.of( "field", new ReferenceExpression("comment.data", STRING), - "path", new ReferenceExpression("comment", STRING) - ) - ); + "path", new ReferenceExpression("comment", STRING))); List projectList = List.of( new NamedExpression( - "nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING)), - null), + "nested(message.info)", DSL.nested(DSL.ref("message.info", STRING)), null), new NamedExpression( - "nested(comment.data)", - DSL.nested(DSL.ref("comment.data", STRING)), - null) - ); + "nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING)), null)); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.nested( - LogicalPlanDSL.relation("schema", table), - nestedArgs, - projectList), - DSL.named("nested(message.info)", - DSL.nested(DSL.ref("message.info", STRING))), - DSL.named("nested(comment.data)", - DSL.nested(DSL.ref("comment.data", STRING))) - ), + LogicalPlanDSL.relation("schema", table), nestedArgs, projectList), + DSL.named("nested(message.info)", DSL.nested(DSL.ref("message.info", STRING))), + DSL.named("nested(comment.data)", DSL.nested(DSL.ref("comment.data", STRING)))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("nested(message.info)", - function("nested", qualifiedName("message", "info")), null), - AstDSL.alias("nested(comment.data)", - function("nested", qualifiedName("comment", "data")), null) - ) - ); + AstDSL.alias( + "nested(message.info)", function("nested", qualifiedName("message", "info")), null), + AstDSL.alias( + "nested(comment.data)", + function("nested", qualifiedName("comment", "data")), + null))); } @Test public void project_nested_invalid_field_throws_exception() { - var exception = assertThrows( - IllegalArgumentException.class, - () -> analyze(AstDSL.projectWithArg( - AstDSL.relation("schema"), - AstDSL.defaultFieldsArgs(), - AstDSL.alias("message", - function("nested", qualifiedName("message")), null) - ) - ) - ); + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + analyze( + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias( + "message", function("nested", qualifiedName("message")), null)))); assertEquals(exception.getMessage(), "Illegal nested field name: message"); } @Test public void project_nested_invalid_arg_type_throws_exception() { - var exception = assertThrows( - IllegalArgumentException.class, - () -> analyze(AstDSL.projectWithArg( - AstDSL.relation("schema"), - AstDSL.defaultFieldsArgs(), - AstDSL.alias("message", - function("nested", stringLiteral("message")), null) - ) - ) - ); + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + analyze( + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias( + "message", function("nested", stringLiteral("message")), null)))); assertEquals(exception.getMessage(), "Illegal nested field name: message"); } @Test public void project_nested_no_args_throws_exception() { - var exception = assertThrows( - IllegalArgumentException.class, - () -> analyze(AstDSL.projectWithArg( - AstDSL.relation("schema"), - AstDSL.defaultFieldsArgs(), - AstDSL.alias("message", - function("nested"), null) - ) - ) - ); - assertEquals(exception.getMessage(), - "on nested object only allowed 2 parameters (field,path) or 1 parameter (field)" - ); + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + analyze( + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias("message", function("nested"), null)))); + assertEquals( + exception.getMessage(), + "on nested object only allowed 2 parameters (field,path) or 1 parameter (field)"); } @Test public void project_nested_too_many_args_throws_exception() { - var exception = assertThrows( - IllegalArgumentException.class, - () -> analyze(AstDSL.projectWithArg( - AstDSL.relation("schema"), - AstDSL.defaultFieldsArgs(), - AstDSL.alias("message", - function("nested", - stringLiteral("message.info"), - stringLiteral("message"), - stringLiteral("message")), - null) - ) - ) - ); - assertEquals(exception.getMessage(), - "on nested object only allowed 2 parameters (field,path) or 1 parameter (field)" - ); + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + analyze( + AstDSL.projectWithArg( + AstDSL.relation("schema"), + AstDSL.defaultFieldsArgs(), + AstDSL.alias( + "message", + function( + "nested", + stringLiteral("message.info"), + stringLiteral("message"), + stringLiteral("message")), + null)))); + assertEquals( + exception.getMessage(), + "on nested object only allowed 2 parameters (field,path) or 1 parameter (field)"); } @Test @@ -975,18 +856,17 @@ public void project_highlight() { assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("fieldA"), args), - DSL.named("highlight(fieldA, pre_tags='', post_tags='')", - new HighlightExpression(DSL.literal("fieldA"))) - ), + LogicalPlanDSL.highlight( + LogicalPlanDSL.relation("schema", table), DSL.literal("fieldA"), args), + DSL.named( + "highlight(fieldA, pre_tags='', post_tags='')", + new HighlightExpression(DSL.literal("fieldA")))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("highlight(fieldA, pre_tags='', post_tags='')", - new HighlightFunction(AstDSL.stringLiteral("fieldA"), args)) - ) - ); + AstDSL.alias( + "highlight(fieldA, pre_tags='', post_tags='')", + new HighlightFunction(AstDSL.stringLiteral("fieldA"), args)))); } @Test @@ -994,18 +874,13 @@ public void project_highlight_wildcard() { Map args = new HashMap<>(); assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.highlight(LogicalPlanDSL.relation("schema", table), - DSL.literal("*"), args), - DSL.named("highlight(*)", - new HighlightExpression(DSL.literal("*"))) - ), + LogicalPlanDSL.highlight( + LogicalPlanDSL.relation("schema", table), DSL.literal("*"), args), + DSL.named("highlight(*)", new HighlightExpression(DSL.literal("*")))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), - AstDSL.alias("highlight(*)", - new HighlightFunction(AstDSL.stringLiteral("*"), args)) - ) - ); + AstDSL.alias("highlight(*)", new HighlightFunction(AstDSL.stringLiteral("*"), args)))); } @Test @@ -1013,8 +888,8 @@ public void remove_source() { assertAnalyzeEqual( LogicalPlanDSL.remove( LogicalPlanDSL.relation("schema", table), - DSL.ref("integer_value", INTEGER), DSL.ref( - "double_value", DOUBLE)), + DSL.ref("integer_value", INTEGER), + DSL.ref("double_value", DOUBLE)), AstDSL.projectWithArg( AstDSL.relation("schema"), Collections.singletonList(argument("exclude", booleanLiteral(true))), @@ -1022,7 +897,8 @@ public void remove_source() { AstDSL.field("double_value"))); } - @Disabled("the project/remove command should shrink the type env. Should be enabled once " + @Disabled( + "the project/remove command should shrink the type env. Should be enabled once " + "https://github.com/opensearch-project/sql/issues/917 is resolved") @Test public void project_source_change_type_env() { @@ -1048,15 +924,12 @@ public void project_values() { LogicalPlanDSL.values(ImmutableList.of(DSL.literal(123))), DSL.named("123", DSL.literal(123)), DSL.named("hello", DSL.literal("hello")), - DSL.named("false", DSL.literal(false)) - ), + DSL.named("false", DSL.literal(false))), AstDSL.project( AstDSL.values(ImmutableList.of(AstDSL.intLiteral(123))), AstDSL.alias("123", AstDSL.intLiteral(123)), AstDSL.alias("hello", AstDSL.stringLiteral("hello")), - AstDSL.alias("false", AstDSL.booleanLiteral(false)) - ) - ); + AstDSL.alias("false", AstDSL.booleanLiteral(false)))); } @SuppressWarnings("unchecked") @@ -1069,8 +942,7 @@ public void sort_with_aggregator() { LogicalPlanDSL.relation("test", table), ImmutableList.of( DSL.named( - "avg(integer_value)", - DSL.avg(DSL.ref("integer_value", INTEGER)))), + "avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), // Aggregator in Sort AST node is replaced with reference by expression optimizer Pair.of(SortOption.DEFAULT_ASC, DSL.ref("avg(integer_value)", DOUBLE))), @@ -1081,12 +953,10 @@ public void sort_with_aggregator() { AstDSL.relation("test"), ImmutableList.of( AstDSL.alias( - "avg(integer_value)", - function("avg", qualifiedName("integer_value")))), + "avg(integer_value)", function("avg", qualifiedName("integer_value")))), emptyList(), ImmutableList.of(AstDSL.alias("string_value", qualifiedName("string_value"))), - emptyList() - ), + emptyList()), field( function("avg", qualifiedName("integer_value")), argument("asc", booleanLiteral(true)))), @@ -1098,40 +968,49 @@ public void sort_with_aggregator() { public void sort_with_options() { ImmutableMap argOptions = ImmutableMap.builder() - .put(new Argument[] {argument("asc", booleanLiteral(true))}, + .put( + new Argument[] {argument("asc", booleanLiteral(true))}, new SortOption(SortOrder.ASC, NullOrder.NULL_FIRST)) - .put(new Argument[] {argument("asc", booleanLiteral(false))}, + .put( + new Argument[] {argument("asc", booleanLiteral(false))}, new SortOption(SortOrder.DESC, NullOrder.NULL_LAST)) - .put(new Argument[] { - argument("asc", booleanLiteral(true)), - argument("nullFirst", booleanLiteral(true))}, + .put( + new Argument[] { + argument("asc", booleanLiteral(true)), argument("nullFirst", booleanLiteral(true)) + }, new SortOption(SortOrder.ASC, NullOrder.NULL_FIRST)) - .put(new Argument[] { - argument("asc", booleanLiteral(true)), - argument("nullFirst", booleanLiteral(false))}, + .put( + new Argument[] { + argument("asc", booleanLiteral(true)), + argument("nullFirst", booleanLiteral(false)) + }, new SortOption(SortOrder.ASC, NullOrder.NULL_LAST)) - .put(new Argument[] { - argument("asc", booleanLiteral(false)), - argument("nullFirst", booleanLiteral(true))}, + .put( + new Argument[] { + argument("asc", booleanLiteral(false)), + argument("nullFirst", booleanLiteral(true)) + }, new SortOption(SortOrder.DESC, NullOrder.NULL_FIRST)) - .put(new Argument[] { - argument("asc", booleanLiteral(false)), - argument("nullFirst", booleanLiteral(false))}, + .put( + new Argument[] { + argument("asc", booleanLiteral(false)), + argument("nullFirst", booleanLiteral(false)) + }, new SortOption(SortOrder.DESC, NullOrder.NULL_LAST)) .build(); - argOptions.forEach((args, expectOption) -> - assertAnalyzeEqual( - LogicalPlanDSL.project( - LogicalPlanDSL.sort( - LogicalPlanDSL.relation("test", table), - Pair.of(expectOption, DSL.ref("integer_value", INTEGER))), - DSL.named("string_value", DSL.ref("string_value", STRING))), - AstDSL.project( - AstDSL.sort( - AstDSL.relation("test"), - field(qualifiedName("integer_value"), args)), - AstDSL.alias("string_value", qualifiedName("string_value"))))); + argOptions.forEach( + (args, expectOption) -> + assertAnalyzeEqual( + LogicalPlanDSL.project( + LogicalPlanDSL.sort( + LogicalPlanDSL.relation("test", table), + Pair.of(expectOption, DSL.ref("integer_value", INTEGER))), + DSL.named("string_value", DSL.ref("string_value", STRING))), + AstDSL.project( + AstDSL.sort( + AstDSL.relation("test"), field(qualifiedName("integer_value"), args)), + AstDSL.alias("string_value", qualifiedName("string_value"))))); } @SuppressWarnings("unchecked") @@ -1156,7 +1035,8 @@ public void window_function() { AstDSL.project( AstDSL.relation("test"), AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")), - AstDSL.alias("window_function", + AstDSL.alias( + "window_function", AstDSL.window( AstDSL.function("row_number"), Collections.singletonList(AstDSL.qualifiedName("string_value")), @@ -1164,11 +1044,7 @@ public void window_function() { ImmutablePair.of(DEFAULT_ASC, AstDSL.qualifiedName("integer_value"))))))); } - /** - * SELECT name FROM ( - * SELECT name, age FROM test - * ) AS schema. - */ + /** SELECT name FROM ( SELECT name, age FROM test ) AS schema. */ @Test public void from_subquery() { assertAnalyzeEqual( @@ -1176,29 +1052,19 @@ public void from_subquery() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), DSL.named("string_value", DSL.ref("string_value", STRING)), - DSL.named("integer_value", DSL.ref("integer_value", INTEGER)) - ), - DSL.named("string_value", DSL.ref("string_value", STRING)) - ), + DSL.named("integer_value", DSL.ref("integer_value", INTEGER))), + DSL.named("string_value", DSL.ref("string_value", STRING))), AstDSL.project( AstDSL.relationSubquery( AstDSL.project( AstDSL.relation("schema"), AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")), - AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value")) - ), - "schema" - ), - AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")) - ) - ); + AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value"))), + "schema"), + AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")))); } - /** - * SELECT * FROM ( - * SELECT name FROM test - * ) AS schema. - */ + /** SELECT * FROM ( SELECT name FROM test ) AS schema. */ @Test public void select_all_from_subquery() { assertAnalyzeEqual( @@ -1206,147 +1072,130 @@ public void select_all_from_subquery() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), DSL.named("string_value", DSL.ref("string_value", STRING))), - DSL.named("string_value", DSL.ref("string_value", STRING)) - ), + DSL.named("string_value", DSL.ref("string_value", STRING))), AstDSL.project( AstDSL.relationSubquery( AstDSL.project( AstDSL.relation("schema"), - AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")) - ), - "schema" - ), - AstDSL.allFields() - ) - ); + AstDSL.alias("string_value", AstDSL.qualifiedName("string_value"))), + "schema"), + AstDSL.allFields())); } /** - * Ensure Nested function falls back to legacy engine when used in GROUP BY clause. - * TODO Remove this test when support is added. + * Ensure Nested function falls back to legacy engine when used in GROUP BY clause. TODO Remove + * this test when support is added. */ @Test public void nested_group_by_clause_throws_syntax_exception() { - SyntaxCheckException exception = assertThrows(SyntaxCheckException.class, - () -> analyze( - AstDSL.project( - AstDSL.agg( - AstDSL.relation("schema"), - emptyList(), - emptyList(), - ImmutableList.of(alias("nested(message.info)", - function("nested", - qualifiedName("message", "info")))), - emptyList() - ))) - ); - assertEquals("Falling back to legacy engine. Nested function is not supported in WHERE," + SyntaxCheckException exception = + assertThrows( + SyntaxCheckException.class, + () -> + analyze( + AstDSL.project( + AstDSL.agg( + AstDSL.relation("schema"), + emptyList(), + emptyList(), + ImmutableList.of( + alias( + "nested(message.info)", + function("nested", qualifiedName("message", "info")))), + emptyList())))); + assertEquals( + "Falling back to legacy engine. Nested function is not supported in WHERE," + " GROUP BY, and HAVING clauses.", exception.getMessage()); } - /** - * SELECT name, AVG(age) FROM test GROUP BY name. - */ + /** SELECT name, AVG(age) FROM test GROUP BY name. */ @Test public void sql_group_by_field() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL - .named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), DSL.named("string_value", DSL.ref("string_value", STRING)), DSL.named("AVG(integer_value)", DSL.ref("AVG(integer_value)", DOUBLE))), AstDSL.project( AstDSL.agg( AstDSL.relation("schema"), - ImmutableList.of(alias("AVG(integer_value)", - aggregate("AVG", qualifiedName("integer_value")))), + ImmutableList.of( + alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value")))), emptyList(), ImmutableList.of(alias("string_value", qualifiedName("string_value"))), emptyList()), AstDSL.alias("string_value", qualifiedName("string_value")), - AstDSL.alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value")))) - ); + AstDSL.alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value"))))); } - /** - * SELECT abs(name), AVG(age) FROM test GROUP BY abs(name). - */ + /** SELECT abs(name), AVG(age) FROM test GROUP BY abs(name). */ @Test public void sql_group_by_function() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL - .named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), - ImmutableList.of(DSL.named("abs(long_value)", - DSL.abs(DSL.ref("long_value", LONG))))), + ImmutableList.of( + DSL.named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("abs(long_value)", DSL.abs(DSL.ref("long_value", LONG))))), DSL.named("abs(long_value)", DSL.ref("abs(long_value)", LONG)), DSL.named("AVG(integer_value)", DSL.ref("AVG(integer_value)", DOUBLE))), AstDSL.project( AstDSL.agg( AstDSL.relation("schema"), - ImmutableList.of(alias("AVG(integer_value)", - aggregate("AVG", qualifiedName("integer_value")))), + ImmutableList.of( + alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value")))), emptyList(), - ImmutableList - .of(alias("abs(long_value)", function("abs", qualifiedName("long_value")))), + ImmutableList.of( + alias("abs(long_value)", function("abs", qualifiedName("long_value")))), emptyList()), AstDSL.alias("abs(long_value)", function("abs", qualifiedName("long_value"))), - AstDSL.alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value")))) - ); + AstDSL.alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value"))))); } - /** - * SELECT abs(name), AVG(age) FROM test GROUP BY ABS(name). - */ + /** SELECT abs(name), AVG(age) FROM test GROUP BY ABS(name). */ @Test public void sql_group_by_function_in_uppercase() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL - .named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), - ImmutableList.of(DSL.named("ABS(long_value)", - DSL.abs(DSL.ref("long_value", LONG))))), + ImmutableList.of( + DSL.named("AVG(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("ABS(long_value)", DSL.abs(DSL.ref("long_value", LONG))))), DSL.named("abs(long_value)", DSL.ref("ABS(long_value)", LONG)), DSL.named("AVG(integer_value)", DSL.ref("AVG(integer_value)", DOUBLE))), AstDSL.project( AstDSL.agg( AstDSL.relation("schema"), - ImmutableList.of(alias("AVG(integer_value)", - aggregate("AVG", qualifiedName("integer_value")))), + ImmutableList.of( + alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value")))), emptyList(), - ImmutableList - .of(alias("ABS(long_value)", function("ABS", qualifiedName("long_value")))), + ImmutableList.of( + alias("ABS(long_value)", function("ABS", qualifiedName("long_value")))), emptyList()), AstDSL.alias("abs(long_value)", function("abs", qualifiedName("long_value"))), - AstDSL.alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value")))) - ); + AstDSL.alias("AVG(integer_value)", aggregate("AVG", qualifiedName("integer_value"))))); } - /** - * SELECT abs(name), abs(avg(age) FROM test GROUP BY abs(name). - */ + /** SELECT abs(name), abs(avg(age) FROM test GROUP BY abs(name). */ @Test public void sql_expression_over_one_aggregation() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL.named("avg(integer_value)", - DSL.avg(DSL.ref("integer_value", INTEGER)))), - ImmutableList.of(DSL.named("abs(long_value)", - DSL.abs(DSL.ref("long_value", LONG))))), + ImmutableList.of( + DSL.named("avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("abs(long_value)", DSL.abs(DSL.ref("long_value", LONG))))), DSL.named("abs(long_value)", DSL.ref("abs(long_value)", LONG)), DSL.named("abs(avg(integer_value)", DSL.abs(DSL.ref("avg(integer_value)", DOUBLE)))), AstDSL.project( @@ -1355,34 +1204,32 @@ public void sql_expression_over_one_aggregation() { ImmutableList.of( alias("avg(integer_value)", aggregate("avg", qualifiedName("integer_value")))), emptyList(), - ImmutableList - .of(alias("abs(long_value)", function("abs", qualifiedName("long_value")))), + ImmutableList.of( + alias("abs(long_value)", function("abs", qualifiedName("long_value")))), emptyList()), AstDSL.alias("abs(long_value)", function("abs", qualifiedName("long_value"))), - AstDSL.alias("abs(avg(integer_value)", - function("abs", aggregate("avg", qualifiedName("integer_value"))))) - ); + AstDSL.alias( + "abs(avg(integer_value)", + function("abs", aggregate("avg", qualifiedName("integer_value")))))); } - /** - * SELECT abs(name), sum(age)-avg(age) FROM test GROUP BY abs(name). - */ + /** SELECT abs(name), sum(age)-avg(age) FROM test GROUP BY abs(name). */ @Test public void sql_expression_over_two_aggregation() { assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL.named("sum(integer_value)", - DSL.sum(DSL.ref("integer_value", INTEGER))), - DSL.named("avg(integer_value)", - DSL.avg(DSL.ref("integer_value", INTEGER)))), - ImmutableList.of(DSL.named("abs(long_value)", - DSL.abs(DSL.ref("long_value", LONG))))), + ImmutableList.of( + DSL.named("sum(integer_value)", DSL.sum(DSL.ref("integer_value", INTEGER))), + DSL.named("avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("abs(long_value)", DSL.abs(DSL.ref("long_value", LONG))))), DSL.named("abs(long_value)", DSL.ref("abs(long_value)", LONG)), - DSL.named("sum(integer_value)-avg(integer_value)", - DSL.subtract(DSL.ref("sum(integer_value)", INTEGER), + DSL.named( + "sum(integer_value)-avg(integer_value)", + DSL.subtract( + DSL.ref("sum(integer_value)", INTEGER), DSL.ref("avg(integer_value)", DOUBLE)))), AstDSL.project( AstDSL.agg( @@ -1391,40 +1238,33 @@ public void sql_expression_over_two_aggregation() { alias("sum(integer_value)", aggregate("sum", qualifiedName("integer_value"))), alias("avg(integer_value)", aggregate("avg", qualifiedName("integer_value")))), emptyList(), - ImmutableList - .of(alias("abs(long_value)", function("abs", qualifiedName("long_value")))), + ImmutableList.of( + alias("abs(long_value)", function("abs", qualifiedName("long_value")))), emptyList()), AstDSL.alias("abs(long_value)", function("abs", qualifiedName("long_value"))), - AstDSL.alias("sum(integer_value)-avg(integer_value)", - function("-", aggregate("sum", qualifiedName("integer_value")), - aggregate("avg", qualifiedName("integer_value"))))) - ); + AstDSL.alias( + "sum(integer_value)-avg(integer_value)", + function( + "-", + aggregate("sum", qualifiedName("integer_value")), + aggregate("avg", qualifiedName("integer_value")))))); } @Test public void limit_offset() { assertAnalyzeEqual( LogicalPlanDSL.project( - LogicalPlanDSL.limit( - LogicalPlanDSL.relation("schema", table), - 1, 1 - ), - DSL.named("integer_value", DSL.ref("integer_value", INTEGER)) - ), + LogicalPlanDSL.limit(LogicalPlanDSL.relation("schema", table), 1, 1), + DSL.named("integer_value", DSL.ref("integer_value", INTEGER))), AstDSL.project( - AstDSL.limit( - AstDSL.relation("schema"), - 1, 1 - ), - AstDSL.alias("integer_value", qualifiedName("integer_value")) - ) - ); + AstDSL.limit(AstDSL.relation("schema"), 1, 1), + AstDSL.alias("integer_value", qualifiedName("integer_value")))); } /** - * SELECT COUNT(NAME) FILTER(WHERE age > 1) FROM test. - * This test is to verify that the aggregator properties are taken - * when wrapping it to {@link org.opensearch.sql.expression.aggregation.NamedAggregator} + * SELECT COUNT(NAME) FILTER(WHERE age > 1) FROM test. This test is to verify that the aggregator + * properties are taken when wrapping it to {@link + * org.opensearch.sql.expression.aggregation.NamedAggregator} */ @Test public void named_aggregator_with_condition() { @@ -1433,36 +1273,37 @@ public void named_aggregator_with_condition() { LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), ImmutableList.of( - DSL.named("count(string_value) filter(where integer_value > 1)", - DSL.count(DSL.ref("string_value", STRING)).condition(DSL.greater(DSL.ref( - "integer_value", INTEGER), DSL.literal(1)))) - ), - emptyList() - ), - DSL.named("count(string_value) filter(where integer_value > 1)", DSL.ref( - "count(string_value) filter(where integer_value > 1)", INTEGER)) - ), + DSL.named( + "count(string_value) filter(where integer_value > 1)", + DSL.count(DSL.ref("string_value", STRING)) + .condition( + DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))))), + emptyList()), + DSL.named( + "count(string_value) filter(where integer_value > 1)", + DSL.ref("count(string_value) filter(where integer_value > 1)", INTEGER))), AstDSL.project( AstDSL.agg( AstDSL.relation("schema"), ImmutableList.of( - alias("count(string_value) filter(where integer_value > 1)", filteredAggregate( - "count", qualifiedName("string_value"), function( - ">", qualifiedName("integer_value"), intLiteral(1))))), + alias( + "count(string_value) filter(where integer_value > 1)", + filteredAggregate( + "count", + qualifiedName("string_value"), + function(">", qualifiedName("integer_value"), intLiteral(1))))), emptyList(), emptyList(), - emptyList() - ), - AstDSL.alias("count(string_value) filter(where integer_value > 1)", filteredAggregate( - "count", qualifiedName("string_value"), function( - ">", qualifiedName("integer_value"), intLiteral(1)))) - ) - ); + emptyList()), + AstDSL.alias( + "count(string_value) filter(where integer_value > 1)", + filteredAggregate( + "count", + qualifiedName("string_value"), + function(">", qualifiedName("integer_value"), intLiteral(1)))))); } - /** - * stats avg(integer_value) by string_value span(long_value, 10). - */ + /** stats avg(integer_value) by string_value span(long_value, 10). */ @Test public void ppl_stats_by_fieldAndSpan() { assertAnalyzeEqual( @@ -1489,10 +1330,13 @@ public void parse_relation_with_grok_expression() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING))), - ImmutableList.of(DSL.named("grok_field", - DSL.grok(DSL.ref("string_value", STRING), DSL.literal("%{IPV4:grok_field}"), - DSL.literal("grok_field")))) - ), + ImmutableList.of( + DSL.named( + "grok_field", + DSL.grok( + DSL.ref("string_value", STRING), + DSL.literal("%{IPV4:grok_field}"), + DSL.literal("grok_field"))))), AstDSL.project( AstDSL.parse( AstDSL.relation("schema"), @@ -1500,8 +1344,7 @@ public void parse_relation_with_grok_expression() { AstDSL.field("string_value"), AstDSL.stringLiteral("%{IPV4:grok_field}"), ImmutableMap.of()), - AstDSL.alias("string_value", qualifiedName("string_value")) - )); + AstDSL.alias("string_value", qualifiedName("string_value")))); } @Test @@ -1510,10 +1353,13 @@ public void parse_relation_with_regex_expression() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING))), - ImmutableList.of(DSL.named("group", - DSL.regex(DSL.ref("string_value", STRING), DSL.literal("(?.*)"), - DSL.literal("group")))) - ), + ImmutableList.of( + DSL.named( + "group", + DSL.regex( + DSL.ref("string_value", STRING), + DSL.literal("(?.*)"), + DSL.literal("group"))))), AstDSL.project( AstDSL.parse( AstDSL.relation("schema"), @@ -1521,25 +1367,28 @@ public void parse_relation_with_regex_expression() { AstDSL.field("string_value"), AstDSL.stringLiteral("(?.*)"), ImmutableMap.of()), - AstDSL.alias("string_value", qualifiedName("string_value")) - )); + AstDSL.alias("string_value", qualifiedName("string_value")))); } @Test public void parse_relation_with_patterns_expression() { - Map arguments = ImmutableMap.builder() - .put("new_field", AstDSL.stringLiteral("custom_field")) - .put("pattern", AstDSL.stringLiteral("custom_pattern")) - .build(); + Map arguments = + ImmutableMap.builder() + .put("new_field", AstDSL.stringLiteral("custom_field")) + .put("pattern", AstDSL.stringLiteral("custom_pattern")) + .build(); assertAnalyzeEqual( LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING))), - ImmutableList.of(DSL.named("custom_field", - DSL.patterns(DSL.ref("string_value", STRING), DSL.literal("custom_pattern"), - DSL.literal("custom_field")))) - ), + ImmutableList.of( + DSL.named( + "custom_field", + DSL.patterns( + DSL.ref("string_value", STRING), + DSL.literal("custom_pattern"), + DSL.literal("custom_field"))))), AstDSL.project( AstDSL.parse( AstDSL.relation("schema"), @@ -1547,8 +1396,7 @@ public void parse_relation_with_patterns_expression() { AstDSL.field("string_value"), AstDSL.stringLiteral("custom_pattern"), arguments), - AstDSL.alias("string_value", qualifiedName("string_value")) - )); + AstDSL.alias("string_value", qualifiedName("string_value")))); } @Test @@ -1557,10 +1405,13 @@ public void parse_relation_with_patterns_expression_no_args() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING))), - ImmutableList.of(DSL.named("patterns_field", - DSL.patterns(DSL.ref("string_value", STRING), DSL.literal(""), - DSL.literal("patterns_field")))) - ), + ImmutableList.of( + DSL.named( + "patterns_field", + DSL.patterns( + DSL.ref("string_value", STRING), + DSL.literal(""), + DSL.literal("patterns_field"))))), AstDSL.project( AstDSL.parse( AstDSL.relation("schema"), @@ -1568,89 +1419,109 @@ public void parse_relation_with_patterns_expression_no_args() { AstDSL.field("string_value"), AstDSL.stringLiteral(""), ImmutableMap.of()), - AstDSL.alias("string_value", qualifiedName("string_value")) - )); + AstDSL.alias("string_value", qualifiedName("string_value")))); } @Test public void kmeanns_relation() { - Map argumentMap = new HashMap() {{ - put("centroids", new Literal(3, DataType.INTEGER)); - put("iterations", new Literal(2, DataType.INTEGER)); - put("distance_type", new Literal("COSINE", DataType.STRING)); - }}; + Map argumentMap = + new HashMap() { + { + put("centroids", new Literal(3, DataType.INTEGER)); + put("iterations", new Literal(2, DataType.INTEGER)); + put("distance_type", new Literal("COSINE", DataType.STRING)); + } + }; assertAnalyzeEqual( - new LogicalMLCommons(LogicalPlanDSL.relation("schema", table), - "kmeans", argumentMap), - new Kmeans(AstDSL.relation("schema"), argumentMap) - ); + new LogicalMLCommons(LogicalPlanDSL.relation("schema", table), "kmeans", argumentMap), + new Kmeans(AstDSL.relation("schema"), argumentMap)); } @Test public void ad_batchRCF_relation() { Map argumentMap = - new HashMap() {{ + new HashMap() { + { put("shingle_size", new Literal(8, DataType.INTEGER)); - }}; + } + }; assertAnalyzeEqual( new LogicalAD(LogicalPlanDSL.relation("schema", table), argumentMap), - new AD(AstDSL.relation("schema"), argumentMap) - ); + new AD(AstDSL.relation("schema"), argumentMap)); } @Test public void ad_fitRCF_relation() { - Map argumentMap = new HashMap() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal("timestamp", DataType.STRING)); - }}; + Map argumentMap = + new HashMap() { + { + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("timestamp", DataType.STRING)); + } + }; assertAnalyzeEqual( - new LogicalAD(LogicalPlanDSL.relation("schema", table), - argumentMap), - new AD(AstDSL.relation("schema"), argumentMap) - ); + new LogicalAD(LogicalPlanDSL.relation("schema", table), argumentMap), + new AD(AstDSL.relation("schema"), argumentMap)); } @Test public void ad_fitRCF_relation_with_time_field() { - Map argumentMap = new HashMap() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - put("time_field", new Literal("ts", DataType.STRING)); - }}; - - LogicalPlan actual = analyze(AstDSL.project( - new AD(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + Map argumentMap = + new HashMap() { + { + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + put("time_field", new Literal("ts", DataType.STRING)); + } + }; + + LogicalPlan actual = + analyze(AstDSL.project(new AD(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 3); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named("score", DSL.ref("score", DOUBLE)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named("anomaly_grade", DSL.ref("anomaly_grade", DOUBLE)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named("ts", DSL.ref("ts", TIMESTAMP)))); } @Test public void ad_fitRCF_relation_without_time_field() { - Map argumentMap = new HashMap<>() {{ - put("shingle_size", new Literal(8, DataType.INTEGER)); - put("time_decay", new Literal(0.0001, DataType.DOUBLE)); - }}; + Map argumentMap = + new HashMap<>() { + { + put("shingle_size", new Literal(8, DataType.INTEGER)); + put("time_decay", new Literal(0.0001, DataType.DOUBLE)); + } + }; - LogicalPlan actual = analyze(AstDSL.project( - new AD(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + LogicalPlan actual = + analyze(AstDSL.project(new AD(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named("score", DSL.ref("score", DOUBLE)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named("anomalous", DSL.ref("anomalous", BOOLEAN)))); } @Test public void table_function() { - assertAnalyzeEqual(new LogicalRelation("query_range", table), - AstDSL.tableFunction(List.of("prometheus", "query_range"), + assertAnalyzeEqual( + new LogicalRelation("query_range", table), + AstDSL.tableFunction( + List.of("prometheus", "query_range"), unresolvedArg("query", stringLiteral("http_latency")), unresolvedArg("starttime", intLiteral(12345)), unresolvedArg("endtime", intLiteral(12345)), @@ -1659,158 +1530,214 @@ public void table_function() { @Test public void table_function_with_no_datasource() { - ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, - () -> analyze(AstDSL.tableFunction(List.of("query_range"), - unresolvedArg("query", stringLiteral("http_latency")), - unresolvedArg("", intLiteral(12345)), - unresolvedArg("", intLiteral(12345)), - unresolvedArg(null, intLiteral(14))))); - assertEquals("unsupported function name: query_range", - exception.getMessage()); + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> + analyze( + AstDSL.tableFunction( + List.of("query_range"), + unresolvedArg("query", stringLiteral("http_latency")), + unresolvedArg("", intLiteral(12345)), + unresolvedArg("", intLiteral(12345)), + unresolvedArg(null, intLiteral(14))))); + assertEquals("unsupported function name: query_range", exception.getMessage()); } @Test public void table_function_with_wrong_datasource() { - ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, - () -> analyze(AstDSL.tableFunction(Arrays.asList("prome", "query_range"), - unresolvedArg("query", stringLiteral("http_latency")), - unresolvedArg("", intLiteral(12345)), - unresolvedArg("", intLiteral(12345)), - unresolvedArg(null, intLiteral(14))))); + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> + analyze( + AstDSL.tableFunction( + Arrays.asList("prome", "query_range"), + unresolvedArg("query", stringLiteral("http_latency")), + unresolvedArg("", intLiteral(12345)), + unresolvedArg("", intLiteral(12345)), + unresolvedArg(null, intLiteral(14))))); assertEquals("unsupported function name: prome.query_range", exception.getMessage()); } @Test public void table_function_with_wrong_table_function() { - ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class, - () -> analyze(AstDSL.tableFunction(Arrays.asList("prometheus", "queryrange"), - unresolvedArg("query", stringLiteral("http_latency")), - unresolvedArg("", intLiteral(12345)), - unresolvedArg("", intLiteral(12345)), - unresolvedArg(null, intLiteral(14))))); + ExpressionEvaluationException exception = + assertThrows( + ExpressionEvaluationException.class, + () -> + analyze( + AstDSL.tableFunction( + Arrays.asList("prometheus", "queryrange"), + unresolvedArg("query", stringLiteral("http_latency")), + unresolvedArg("", intLiteral(12345)), + unresolvedArg("", intLiteral(12345)), + unresolvedArg(null, intLiteral(14))))); assertEquals("unsupported function name: queryrange", exception.getMessage()); } @Test public void show_datasources() { - assertAnalyzeEqual(new LogicalRelation(DATASOURCES_TABLE_NAME, - new DataSourceTable(dataSourceService)), + assertAnalyzeEqual( + new LogicalRelation(DATASOURCES_TABLE_NAME, new DataSourceTable(dataSourceService)), AstDSL.relation(qualifiedName(DATASOURCES_TABLE_NAME))); } @Test public void ml_relation_unsupported_action() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal("unsupported", DataType.STRING)); - put(ALGO, new Literal(KMEANS, DataType.STRING)); - }}; + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal("unsupported", DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + } + }; IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields()))); + assertThrows( + IllegalArgumentException.class, + () -> + analyze( + AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields()))); assertEquals( - "Action error. Please indicate train, predict or trainandpredict.", - exception.getMessage()); + "Action error. Please indicate train, predict or trainandpredict.", exception.getMessage()); } @Test public void ml_relation_unsupported_algorithm() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal(PREDICT, DataType.STRING)); - put(ALGO, new Literal("unsupported", DataType.STRING)); - }}; + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal("unsupported", DataType.STRING)); + } + }; IllegalArgumentException exception = - assertThrows( - IllegalArgumentException.class, - () -> analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields()))); - assertEquals( - "Unsupported algorithm: unsupported", - exception.getMessage()); + assertThrows( + IllegalArgumentException.class, + () -> + analyze( + AstDSL.project( + new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields()))); + assertEquals("Unsupported algorithm: unsupported", exception.getMessage()); } @Test public void ml_relation_train_sync() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal(TRAIN, DataType.STRING)); - put(ALGO, new Literal(KMEANS, DataType.STRING)); - }}; - - LogicalPlan actual = analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal(TRAIN, DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + } + }; + + LogicalPlan actual = + analyze(AstDSL.project(new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(STATUS, DSL.ref(STATUS, STRING)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(MODELID, DSL.ref(MODELID, STRING)))); } @Test public void ml_relation_train_async() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal(TRAIN, DataType.STRING)); - put(ALGO, new Literal(KMEANS, DataType.STRING)); - put(ASYNC, new Literal(true, DataType.BOOLEAN)); - }}; - - LogicalPlan actual = analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal(TRAIN, DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + put(ASYNC, new Literal(true, DataType.BOOLEAN)); + } + }; + + LogicalPlan actual = + analyze(AstDSL.project(new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(STATUS, DSL.ref(STATUS, STRING)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(TASKID, DSL.ref(TASKID, STRING)))); } @Test public void ml_relation_predict_kmeans() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal(PREDICT, DataType.STRING)); - put(ALGO, new Literal(KMEANS, DataType.STRING)); - }}; - - LogicalPlan actual = analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal(KMEANS, DataType.STRING)); + } + }; + + LogicalPlan actual = + analyze(AstDSL.project(new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 1); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(CLUSTERID, DSL.ref(CLUSTERID, INTEGER)))); } @Test public void ml_relation_predict_rcf_with_time_field() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal(PREDICT, DataType.STRING)); - put(ALGO, new Literal(RCF, DataType.STRING)); - put(RCF_TIME_FIELD, new Literal("ts", DataType.STRING)); - }}; - - LogicalPlan actual = analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal(RCF, DataType.STRING)); + put(RCF_TIME_FIELD, new Literal("ts", DataType.STRING)); + } + }; + + LogicalPlan actual = + analyze(AstDSL.project(new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 3); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(RCF_SCORE, DSL.ref(RCF_SCORE, DOUBLE)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(RCF_ANOMALY_GRADE, DSL.ref(RCF_ANOMALY_GRADE, DOUBLE)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named("ts", DSL.ref("ts", TIMESTAMP)))); } @Test public void ml_relation_predict_rcf_without_time_field() { - Map argumentMap = new HashMap<>() {{ - put(ACTION, new Literal(PREDICT, DataType.STRING)); - put(ALGO, new Literal(RCF, DataType.STRING)); - }}; - - LogicalPlan actual = analyze(AstDSL.project( - new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); + Map argumentMap = + new HashMap<>() { + { + put(ACTION, new Literal(PREDICT, DataType.STRING)); + put(ALGO, new Literal(RCF, DataType.STRING)); + } + }; + + LogicalPlan actual = + analyze(AstDSL.project(new ML(AstDSL.relation("schema"), argumentMap), AstDSL.allFields())); assertTrue(((LogicalProject) actual).getProjectList().size() >= 2); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(RCF_SCORE, DSL.ref(RCF_SCORE, DOUBLE)))); - assertTrue(((LogicalProject) actual).getProjectList() + assertTrue( + ((LogicalProject) actual) + .getProjectList() .contains(DSL.named(RCF_ANOMALOUS, DSL.ref(RCF_ANOMALOUS, BOOLEAN)))); } @@ -1825,8 +1752,10 @@ public void visit_paginate() { void visit_cursor() { LogicalPlan actual = analyze((new FetchCursor("test"))); assertTrue(actual instanceof LogicalFetchCursor); - assertEquals(new LogicalFetchCursor("test", - dataSourceService.getDataSource("@opensearch").getStorageEngine()), actual); + assertEquals( + new LogicalFetchCursor( + "test", dataSourceService.getDataSource("@opensearch").getStorageEngine()), + actual); } @Test @@ -1835,7 +1764,7 @@ public void visit_close_cursor() { assertAll( () -> assertTrue(analyzed instanceof LogicalCloseCursor), () -> assertTrue(analyzed.getChild().get(0) instanceof LogicalFetchCursor), - () -> assertEquals("pewpew", ((LogicalFetchCursor) analyzed.getChild().get(0)).getCursor()) - ); + () -> + assertEquals("pewpew", ((LogicalFetchCursor) analyzed.getChild().get(0)).getCursor())); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java index b6e2600041..f09bc5d380 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java +++ b/core/src/test/java/org/opensearch/sql/analysis/AnalyzerTestBase.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -47,7 +46,6 @@ import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; - public class AnalyzerTestBase { protected Map typeMapping() { @@ -92,31 +90,34 @@ public Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableNam } protected Table table() { - return Optional.ofNullable(table).orElseGet(() -> new Table() { - @Override - public boolean exists() { - return true; - } - - @Override - public void create(Map schema) { - throw new UnsupportedOperationException("Create table is not supported"); - } - - @Override - public Map getFieldTypes() { - return typeMapping(); - } - - @Override - public PhysicalPlan implement(LogicalPlan plan) { - throw new UnsupportedOperationException(); - } - - public Map getReservedFieldTypes() { - return ImmutableMap.of("_test", STRING); - } - }); + return Optional.ofNullable(table) + .orElseGet( + () -> + new Table() { + @Override + public boolean exists() { + return true; + } + + @Override + public void create(Map schema) { + throw new UnsupportedOperationException("Create table is not supported"); + } + + @Override + public Map getFieldTypes() { + return typeMapping(); + } + + @Override + public PhysicalPlan implement(LogicalPlan plan) { + throw new UnsupportedOperationException(); + } + + public Map getReservedFieldTypes() { + return ImmutableMap.of("_test", STRING); + } + }); } protected DataSourceService dataSourceService() { @@ -125,10 +126,12 @@ protected DataSourceService dataSourceService() { protected SymbolTable symbolTable() { SymbolTable symbolTable = new SymbolTable(); - typeMapping().entrySet() + typeMapping() + .entrySet() .forEach( - entry -> symbolTable - .store(new Symbol(Namespace.FIELD_NAME, entry.getKey()), entry.getValue())); + entry -> + symbolTable.store( + new Symbol(Namespace.FIELD_NAME, entry.getKey()), entry.getValue())); return symbolTable; } @@ -154,8 +157,8 @@ protected Environment typeEnv() { protected Analyzer analyzer = analyzer(expressionAnalyzer(), dataSourceService); - protected Analyzer analyzer(ExpressionAnalyzer expressionAnalyzer, - DataSourceService dataSourceService) { + protected Analyzer analyzer( + ExpressionAnalyzer expressionAnalyzer, DataSourceService dataSourceService) { BuiltinFunctionRepository functionRepository = BuiltinFunctionRepository.getInstance(); return new Analyzer(expressionAnalyzer, dataSourceService, functionRepository); } @@ -182,18 +185,22 @@ protected LogicalPlan analyze(UnresolvedPlan unresolvedPlan) { private class DefaultDataSourceService implements DataSourceService { - private final DataSource opensearchDataSource = new DataSource(DEFAULT_DATASOURCE_NAME, - DataSourceType.OPENSEARCH, storageEngine()); - private final DataSource prometheusDataSource - = new DataSource("prometheus", DataSourceType.PROMETHEUS, prometheusStorageEngine()); - + private final DataSource opensearchDataSource = + new DataSource(DEFAULT_DATASOURCE_NAME, DataSourceType.OPENSEARCH, storageEngine()); + private final DataSource prometheusDataSource = + new DataSource("prometheus", DataSourceType.PROMETHEUS, prometheusStorageEngine()); @Override public Set getDataSourceMetadata(boolean isDefaultDataSourceRequired) { return Stream.of(opensearchDataSource, prometheusDataSource) - .map(ds -> new DataSourceMetadata(ds.getName(), - ds.getConnectorType(),Collections.emptyList(), - ImmutableMap.of())).collect(Collectors.toSet()); + .map( + ds -> + new DataSourceMetadata( + ds.getName(), + ds.getConnectorType(), + Collections.emptyList(), + ImmutableMap.of())) + .collect(Collectors.toSet()); } @Override @@ -216,18 +223,14 @@ public DataSource getDataSource(String dataSourceName) { } @Override - public void updateDataSource(DataSourceMetadata dataSourceMetadata) { - - } + public void updateDataSource(DataSourceMetadata dataSourceMetadata) {} @Override - public void deleteDataSource(String dataSourceName) { - } + public void deleteDataSource(String dataSourceName) {} @Override public Boolean dataSourceExists(String dataSourceName) { - return dataSourceName.equals(DEFAULT_DATASOURCE_NAME) - || dataSourceName.equals("prometheus"); + return dataSourceName.equals(DEFAULT_DATASOURCE_NAME) || dataSourceName.equals("prometheus"); } } @@ -239,8 +242,8 @@ private class TestTableFunctionImplementation implements TableFunctionImplementa private Table table; - public TestTableFunctionImplementation(FunctionName functionName, List arguments, - Table table) { + public TestTableFunctionImplementation( + FunctionName functionName, List arguments, Table table) { this.functionName = functionName; this.arguments = arguments; this.table = table; diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java index 5a05c79132..9d30ebeaab 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionAnalyzerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static java.util.Collections.emptyList; @@ -57,64 +56,50 @@ class ExpressionAnalyzerTest extends AnalyzerTestBase { public void equal() { assertAnalyzeEqual( DSL.equal(DSL.ref("integer_value", INTEGER), DSL.literal(integerValue(1))), - AstDSL.equalTo(AstDSL.unresolvedAttr("integer_value"), AstDSL.intLiteral(1)) - ); + AstDSL.equalTo(AstDSL.unresolvedAttr("integer_value"), AstDSL.intLiteral(1))); } @Test public void and() { assertAnalyzeEqual( DSL.and(DSL.ref("boolean_value", BOOLEAN), DSL.literal(LITERAL_TRUE)), - AstDSL.and(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true)) - ); + AstDSL.and(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true))); } @Test public void or() { assertAnalyzeEqual( DSL.or(DSL.ref("boolean_value", BOOLEAN), DSL.literal(LITERAL_TRUE)), - AstDSL.or(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true)) - ); + AstDSL.or(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true))); } @Test public void xor() { assertAnalyzeEqual( DSL.xor(DSL.ref("boolean_value", BOOLEAN), DSL.literal(LITERAL_TRUE)), - AstDSL.xor(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true)) - ); + AstDSL.xor(AstDSL.unresolvedAttr("boolean_value"), AstDSL.booleanLiteral(true))); } @Test public void not() { assertAnalyzeEqual( DSL.not(DSL.ref("boolean_value", BOOLEAN)), - AstDSL.not(AstDSL.unresolvedAttr("boolean_value")) - ); + AstDSL.not(AstDSL.unresolvedAttr("boolean_value"))); } @Test public void qualified_name() { - assertAnalyzeEqual( - DSL.ref("integer_value", INTEGER), - qualifiedName("integer_value") - ); + assertAnalyzeEqual(DSL.ref("integer_value", INTEGER), qualifiedName("integer_value")); } @Test public void between() { assertAnalyzeEqual( DSL.and( - DSL.gte( - DSL.ref("integer_value", INTEGER), - DSL.literal(20)), - DSL.lte( - DSL.ref("integer_value", INTEGER), - DSL.literal(30))), + DSL.gte(DSL.ref("integer_value", INTEGER), DSL.literal(20)), + DSL.lte(DSL.ref("integer_value", INTEGER), DSL.literal(30))), AstDSL.between( - qualifiedName("integer_value"), - AstDSL.intLiteral(20), - AstDSL.intLiteral(30))); + qualifiedName("integer_value"), AstDSL.intLiteral(20), AstDSL.intLiteral(30))); } @Test @@ -149,36 +134,38 @@ public void case_conditions() { AstDSL.caseWhen( null, AstDSL.when( - AstDSL.function(">", - qualifiedName("integer_value"), - AstDSL.intLiteral(50)), AstDSL.stringLiteral("Fifty")), + AstDSL.function(">", qualifiedName("integer_value"), AstDSL.intLiteral(50)), + AstDSL.stringLiteral("Fifty")), AstDSL.when( - AstDSL.function(">", - qualifiedName("integer_value"), - AstDSL.intLiteral(30)), AstDSL.stringLiteral("Thirty")))); + AstDSL.function(">", qualifiedName("integer_value"), AstDSL.intLiteral(30)), + AstDSL.stringLiteral("Thirty")))); } @Test public void castAnalyzer() { assertAnalyzeEqual( DSL.castInt(DSL.ref("boolean_value", BOOLEAN)), - AstDSL.cast(AstDSL.unresolvedAttr("boolean_value"), AstDSL.stringLiteral("INT")) - ); + AstDSL.cast(AstDSL.unresolvedAttr("boolean_value"), AstDSL.stringLiteral("INT"))); - assertThrows(IllegalStateException.class, () -> analyze(AstDSL.cast(AstDSL.unresolvedAttr( - "boolean_value"), AstDSL.stringLiteral("INTERVAL")))); + assertThrows( + IllegalStateException.class, + () -> + analyze( + AstDSL.cast( + AstDSL.unresolvedAttr("boolean_value"), AstDSL.stringLiteral("INTERVAL")))); } @Test public void case_with_default_result_type_different() { - UnresolvedExpression caseWhen = AstDSL.caseWhen( - qualifiedName("integer_value"), - AstDSL.intLiteral(60), - AstDSL.when(AstDSL.intLiteral(30), AstDSL.stringLiteral("Thirty")), - AstDSL.when(AstDSL.intLiteral(50), AstDSL.stringLiteral("Fifty"))); - - SemanticCheckException exception = assertThrows( - SemanticCheckException.class, () -> analyze(caseWhen)); + UnresolvedExpression caseWhen = + AstDSL.caseWhen( + qualifiedName("integer_value"), + AstDSL.intLiteral(60), + AstDSL.when(AstDSL.intLiteral(30), AstDSL.stringLiteral("Thirty")), + AstDSL.when(AstDSL.intLiteral(50), AstDSL.stringLiteral("Fifty"))); + + SemanticCheckException exception = + assertThrows(SemanticCheckException.class, () -> analyze(caseWhen)); assertEquals( "All result types of CASE clause must be the same, but found [STRING, STRING, INTEGER]", exception.getMessage()); @@ -187,8 +174,7 @@ public void case_with_default_result_type_different() { @Test public void scalar_window_function() { assertAnalyzeEqual( - DSL.rank(), - AstDSL.window(AstDSL.function("rank"), emptyList(), emptyList())); + DSL.rank(), AstDSL.window(AstDSL.function("rank"), emptyList(), emptyList())); } @SuppressWarnings("unchecked") @@ -197,9 +183,7 @@ public void aggregate_window_function() { assertAnalyzeEqual( new AggregateWindowFunction(DSL.avg(DSL.ref("integer_value", INTEGER))), AstDSL.window( - AstDSL.aggregate("avg", qualifiedName("integer_value")), - emptyList(), - emptyList())); + AstDSL.aggregate("avg", qualifiedName("integer_value")), emptyList(), emptyList())); } @Test @@ -207,26 +191,24 @@ public void qualified_name_with_qualifier() { analysisContext.push(); analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); assertAnalyzeEqual( - DSL.ref("integer_value", INTEGER), - qualifiedName("index_alias", "integer_value") - ); + DSL.ref("integer_value", INTEGER), qualifiedName("index_alias", "integer_value")); analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "object_field"), STRUCT); - analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "object_field.integer_value"), - INTEGER); + analysisContext + .peek() + .define(new Symbol(Namespace.FIELD_NAME, "object_field.integer_value"), INTEGER); assertAnalyzeEqual( DSL.ref("object_field.integer_value", INTEGER), - qualifiedName("object_field", "integer_value") - ); + qualifiedName("object_field", "integer_value")); SyntaxCheckException exception = - assertThrows(SyntaxCheckException.class, + assertThrows( + SyntaxCheckException.class, () -> analyze(qualifiedName("nested_field", "integer_value"))); assertEquals( "The qualifier [nested_field] of qualified name [nested_field.integer_value] " + "must be an field name, index name or its alias", - exception.getMessage() - ); + exception.getMessage()); analysisContext.pop(); } @@ -237,21 +219,12 @@ public void qualified_name_with_reserved_symbol() { analysisContext.peek().addReservedWord(new Symbol(Namespace.FIELD_NAME, "_reserved"), STRING); analysisContext.peek().addReservedWord(new Symbol(Namespace.FIELD_NAME, "_priority"), FLOAT); analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); - assertAnalyzeEqual( - DSL.ref("_priority", FLOAT), - qualifiedName("_priority") - ); - assertAnalyzeEqual( - DSL.ref("_reserved", STRING), - qualifiedName("index_alias", "_reserved") - ); + assertAnalyzeEqual(DSL.ref("_priority", FLOAT), qualifiedName("_priority")); + assertAnalyzeEqual(DSL.ref("_reserved", STRING), qualifiedName("index_alias", "_reserved")); // reserved fields take priority over symbol table analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "_reserved"), LONG); - assertAnalyzeEqual( - DSL.ref("_reserved", STRING), - qualifiedName("index_alias", "_reserved") - ); + assertAnalyzeEqual(DSL.ref("_reserved", STRING), qualifiedName("index_alias", "_reserved")); analysisContext.pop(); } @@ -265,9 +238,7 @@ public void interval() { @Test public void all_fields() { - assertAnalyzeEqual( - DSL.literal("*"), - AllFields.of()); + assertAnalyzeEqual(DSL.literal("*"), AllFields.of()); } @Test @@ -281,25 +252,30 @@ public void case_clause() { AstDSL.caseWhen( AstDSL.nullLiteral(), AstDSL.when( - AstDSL.function("=", - qualifiedName("integer_value"), - AstDSL.intLiteral(30)), + AstDSL.function("=", qualifiedName("integer_value"), AstDSL.intLiteral(30)), AstDSL.stringLiteral("test")))); } @Test public void undefined_var_semantic_check_failed() { - SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> analyze( - AstDSL.and(AstDSL.unresolvedAttr("undefined_field"), AstDSL.booleanLiteral(true)))); - assertEquals("can't resolve Symbol(namespace=FIELD_NAME, name=undefined_field) in type env", + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.and( + AstDSL.unresolvedAttr("undefined_field"), AstDSL.booleanLiteral(true)))); + assertEquals( + "can't resolve Symbol(namespace=FIELD_NAME, name=undefined_field) in type env", exception.getMessage()); } @Test public void undefined_aggregation_function() { - SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> analyze(AstDSL.aggregate("ESTDC_ERROR", field("integer_value")))); + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> analyze(AstDSL.aggregate("ESTDC_ERROR", field("integer_value")))); assertEquals("Unsupported aggregation function ESTDC_ERROR", exception.getMessage()); } @@ -308,25 +284,24 @@ public void aggregation_filter() { assertAnalyzeEqual( DSL.avg(DSL.ref("integer_value", INTEGER)) .condition(DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), - AstDSL.filteredAggregate("avg", qualifiedName("integer_value"), - function(">", qualifiedName("integer_value"), intLiteral(1))) - ); + AstDSL.filteredAggregate( + "avg", + qualifiedName("integer_value"), + function(">", qualifiedName("integer_value"), intLiteral(1)))); } @Test public void variance_mapto_varPop() { assertAnalyzeEqual( DSL.varPop(DSL.ref("integer_value", INTEGER)), - AstDSL.aggregate("variance", qualifiedName("integer_value")) - ); + AstDSL.aggregate("variance", qualifiedName("integer_value"))); } @Test public void distinct_count() { assertAnalyzeEqual( DSL.distinctCount(DSL.ref("integer_value", INTEGER)), - AstDSL.distinctAggregate("count", qualifiedName("integer_value")) - ); + AstDSL.distinctAggregate("count", qualifiedName("integer_value"))); } @Test @@ -334,48 +309,49 @@ public void filtered_distinct_count() { assertAnalyzeEqual( DSL.distinctCount(DSL.ref("integer_value", INTEGER)) .condition(DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), - AstDSL.filteredDistinctCount("count", qualifiedName("integer_value"), function( - ">", qualifiedName("integer_value"), intLiteral(1))) - ); + AstDSL.filteredDistinctCount( + "count", + qualifiedName("integer_value"), + function(">", qualifiedName("integer_value"), intLiteral(1)))); } @Test public void take_aggregation() { assertAnalyzeEqual( DSL.take(DSL.ref("string_value", STRING), DSL.literal(10)), - AstDSL.aggregate("take", qualifiedName("string_value"), intLiteral(10)) - ); + AstDSL.aggregate("take", qualifiedName("string_value"), intLiteral(10))); } @Test public void named_argument() { assertAnalyzeEqual( DSL.namedArgument("arg_name", DSL.literal("query")), - AstDSL.unresolvedArg("arg_name", stringLiteral("query")) - ); + AstDSL.unresolvedArg("arg_name", stringLiteral("query"))); } @Test public void named_parse_expression() { analysisContext.push(); analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "string_field"), STRING); - analysisContext.getNamedParseExpressions() - .add(DSL.named("group", - DSL.regex(ref("string_field", STRING), DSL.literal("(?\\d+)"), - DSL.literal("group")))); + analysisContext + .getNamedParseExpressions() + .add( + DSL.named( + "group", + DSL.regex( + ref("string_field", STRING), + DSL.literal("(?\\d+)"), + DSL.literal("group")))); assertAnalyzeEqual( - DSL.regex(ref("string_field", STRING), DSL.literal("(?\\d+)"), - DSL.literal("group")), - qualifiedName("group") - ); + DSL.regex(ref("string_field", STRING), DSL.literal("(?\\d+)"), DSL.literal("group")), + qualifiedName("group")); } @Test public void named_non_parse_expression() { analysisContext.push(); analysisContext.peek().define(new Symbol(Namespace.FIELD_NAME, "string_field"), STRING); - analysisContext.getNamedParseExpressions() - .add(DSL.named("string_field", DSL.literal("123"))); + analysisContext.getNamedParseExpressions().add(DSL.named("string_field", DSL.literal("123"))); assertAnalyzeEqual(DSL.ref("string_field", STRING), qualifiedName("string_field")); } @@ -385,25 +361,29 @@ void match_bool_prefix_expression() { DSL.match_bool_prefix( DSL.namedArgument("field", DSL.literal("field_value1")), DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function("match_bool_prefix", + AstDSL.function( + "match_bool_prefix", AstDSL.unresolvedArg("field", stringLiteral("field_value1")), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @Test void match_bool_prefix_wrong_expression() { - assertThrows(SemanticCheckException.class, - () -> analyze(AstDSL.function("match_bool_prefix", - AstDSL.unresolvedArg("field", stringLiteral("fieldA")), - AstDSL.unresolvedArg("query", floatLiteral(1.2f))))); + assertThrows( + SemanticCheckException.class, + () -> + analyze( + AstDSL.function( + "match_bool_prefix", + AstDSL.unresolvedArg("field", stringLiteral("fieldA")), + AstDSL.unresolvedArg("query", floatLiteral(1.2f))))); } @Test void visit_span() { assertAnalyzeEqual( DSL.span(DSL.ref("integer_value", INTEGER), DSL.literal(1), ""), - AstDSL.span(qualifiedName("integer_value"), intLiteral(1), SpanUnit.NONE) - ); + AstDSL.span(qualifiedName("integer_value"), intLiteral(1), SpanUnit.NONE)); } @Test @@ -425,13 +405,16 @@ void visit_in() { void multi_match_expression() { assertAnalyzeEqual( DSL.multi_match( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function("multi_match", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field_value1", 1.F))), + AstDSL.function( + "multi_match", + AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -439,14 +422,17 @@ void multi_match_expression() { void multi_match_expression_with_params() { assertAnalyzeEqual( DSL.multi_match( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query")), DSL.namedArgument("analyzer", DSL.literal("keyword"))), - AstDSL.function("multi_match", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field_value1", 1.F))), + AstDSL.function( + "multi_match", + AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")), AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); } @@ -455,14 +441,20 @@ void multi_match_expression_with_params() { void multi_match_expression_two_fields() { assertAnalyzeEqual( DSL.multi_match( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F), - "field_value2", ExprValueUtils.floatValue(.3F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of( + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function("multi_match", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field_value1", 1.F, "field_value2", .3F))), + AstDSL.function( + "multi_match", + AstDSL.unresolvedArg( + "fields", + new RelevanceFieldList(ImmutableMap.of("field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -470,13 +462,16 @@ void multi_match_expression_two_fields() { void simple_query_string_expression() { assertAnalyzeEqual( DSL.simple_query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function("simple_query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field_value1", 1.F))), + AstDSL.function( + "simple_query_string", + AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @@ -484,14 +479,17 @@ void simple_query_string_expression() { void simple_query_string_expression_with_params() { assertAnalyzeEqual( DSL.simple_query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("sample query")), DSL.namedArgument("analyzer", DSL.literal("keyword"))), - AstDSL.function("simple_query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field_value1", 1.F))), + AstDSL.function( + "simple_query_string", + AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")), AstDSL.unresolvedArg("analyzer", stringLiteral("keyword")))); } @@ -500,37 +498,44 @@ void simple_query_string_expression_with_params() { void simple_query_string_expression_two_fields() { assertAnalyzeEqual( DSL.simple_query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F), - "field_value2", ExprValueUtils.floatValue(.3F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of( + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("sample query"))), - AstDSL.function("simple_query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field_value1", 1.F, "field_value2", .3F))), + AstDSL.function( + "simple_query_string", + AstDSL.unresolvedArg( + "fields", + new RelevanceFieldList(ImmutableMap.of("field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("sample query")))); } @Test void query_expression() { assertAnalyzeEqual( - DSL.query( - DSL.namedArgument("query", DSL.literal("field:query"))), - AstDSL.function("query", - AstDSL.unresolvedArg("query", stringLiteral("field:query")))); + DSL.query(DSL.namedArgument("query", DSL.literal("field:query"))), + AstDSL.function("query", AstDSL.unresolvedArg("query", stringLiteral("field:query")))); } @Test void query_string_expression() { assertAnalyzeEqual( DSL.query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("query_value"))), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field_value1", 1.F))), + AstDSL.function( + "query_string", + AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } @@ -538,14 +543,17 @@ void query_string_expression() { void query_string_expression_with_params() { assertAnalyzeEqual( DSL.query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of("field_value1", ExprValueUtils.floatValue(1.F)))))), DSL.namedArgument("query", DSL.literal("query_value")), DSL.namedArgument("escape", DSL.literal("false"))), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of( - "field_value1", 1.F))), + AstDSL.function( + "query_string", + AstDSL.unresolvedArg("fields", new RelevanceFieldList(Map.of("field_value1", 1.F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")), AstDSL.unresolvedArg("escape", stringLiteral("false")))); } @@ -554,14 +562,20 @@ void query_string_expression_with_params() { void query_string_expression_two_fields() { assertAnalyzeEqual( DSL.query_string( - DSL.namedArgument("fields", DSL.literal( - new ExprTupleValue(new LinkedHashMap<>(ImmutableMap.of( - "field_value1", ExprValueUtils.floatValue(1.F), - "field_value2", ExprValueUtils.floatValue(.3F)))))), + DSL.namedArgument( + "fields", + DSL.literal( + new ExprTupleValue( + new LinkedHashMap<>( + ImmutableMap.of( + "field_value1", ExprValueUtils.floatValue(1.F), + "field_value2", ExprValueUtils.floatValue(.3F)))))), DSL.namedArgument("query", DSL.literal("query_value"))), - AstDSL.function("query_string", - AstDSL.unresolvedArg("fields", new RelevanceFieldList(ImmutableMap.of( - "field_value1", 1.F, "field_value2", .3F))), + AstDSL.function( + "query_string", + AstDSL.unresolvedArg( + "fields", + new RelevanceFieldList(ImmutableMap.of("field_value1", 1.F, "field_value2", .3F))), AstDSL.unresolvedArg("query", stringLiteral("query_value")))); } @@ -571,7 +585,8 @@ void wildcard_query_expression() { DSL.wildcard_query( DSL.namedArgument("field", DSL.literal("test")), DSL.namedArgument("query", DSL.literal("query_value*"))), - AstDSL.function("wildcard_query", + AstDSL.function( + "wildcard_query", unresolvedArg("field", stringLiteral("test")), unresolvedArg("query", stringLiteral("query_value*")))); } @@ -585,7 +600,8 @@ void wildcard_query_expression_all_params() { DSL.namedArgument("boost", DSL.literal("1.5")), DSL.namedArgument("case_insensitive", DSL.literal("true")), DSL.namedArgument("rewrite", DSL.literal("scoring_boolean"))), - AstDSL.function("wildcard_query", + AstDSL.function( + "wildcard_query", unresolvedArg("field", stringLiteral("test")), unresolvedArg("query", stringLiteral("query_value*")), unresolvedArg("boost", stringLiteral("1.5")), @@ -603,154 +619,144 @@ public void match_phrase_prefix_all_params() { DSL.namedArgument("boost", "1.5"), DSL.namedArgument("analyzer", "standard"), DSL.namedArgument("max_expansions", "4"), - DSL.namedArgument("zero_terms_query", "NONE") - ), - AstDSL.function("match_phrase_prefix", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")), - unresolvedArg("boost", stringLiteral("1.5")), - unresolvedArg("analyzer", stringLiteral("standard")), - unresolvedArg("max_expansions", stringLiteral("4")), - unresolvedArg("zero_terms_query", stringLiteral("NONE")) - ) - ); - } - - @Test void score_function_expression() { - assertAnalyzeEqual( - DSL.score( - DSL.namedArgument("RelevanceQuery", - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3") - ) - )), - AstDSL.function("score", - unresolvedArg("RelevanceQuery", - AstDSL.function("match_phrase_prefix", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")) - ) - ) - ) - ); - } - - @Test void score_function_with_boost() { - assertAnalyzeEqual( - DSL.score( - DSL.namedArgument("RelevanceQuery", - DSL.match_phrase_prefix( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("boost", "3.0") - )), - DSL.namedArgument("boost", "2") - ), - AstDSL.function("score", - unresolvedArg("RelevanceQuery", - AstDSL.function("match_phrase_prefix", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("boost", stringLiteral("3.0")) - ) - ), - unresolvedArg("boost", stringLiteral("2")) - ) - ); - } - - @Test void score_query_function_expression() { - assertAnalyzeEqual( - DSL.score_query( - DSL.namedArgument("RelevanceQuery", - DSL.wildcard_query( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query") - ) - )), - AstDSL.function("score_query", - unresolvedArg("RelevanceQuery", - AstDSL.function("wildcard_query", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")) - ) - ) - ) - ); - } - - @Test void score_query_function_with_boost() { - assertAnalyzeEqual( - DSL.score_query( - DSL.namedArgument("RelevanceQuery", - DSL.wildcard_query( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query") - ) - ), - DSL.namedArgument("boost", "2.0") - ), - AstDSL.function("score_query", - unresolvedArg("RelevanceQuery", - AstDSL.function("wildcard_query", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")) - ) - ), - unresolvedArg("boost", stringLiteral("2.0")) - ) - ); - } - - @Test void scorequery_function_expression() { - assertAnalyzeEqual( - DSL.scorequery( - DSL.namedArgument("RelevanceQuery", - DSL.simple_query_string( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3") - ) - )), - AstDSL.function("scorequery", - unresolvedArg("RelevanceQuery", - AstDSL.function("simple_query_string", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")) - ) - ) - ) - ); + DSL.namedArgument("zero_terms_query", "NONE")), + AstDSL.function( + "match_phrase_prefix", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")), + unresolvedArg("boost", stringLiteral("1.5")), + unresolvedArg("analyzer", stringLiteral("standard")), + unresolvedArg("max_expansions", stringLiteral("4")), + unresolvedArg("zero_terms_query", stringLiteral("NONE")))); + } + + @Test + void score_function_expression() { + assertAnalyzeEqual( + DSL.score( + DSL.namedArgument( + "RelevanceQuery", + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3")))), + AstDSL.function( + "score", + unresolvedArg( + "RelevanceQuery", + AstDSL.function( + "match_phrase_prefix", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")))))); + } + + @Test + void score_function_with_boost() { + assertAnalyzeEqual( + DSL.score( + DSL.namedArgument( + "RelevanceQuery", + DSL.match_phrase_prefix( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("boost", "3.0"))), + DSL.namedArgument("boost", "2")), + AstDSL.function( + "score", + unresolvedArg( + "RelevanceQuery", + AstDSL.function( + "match_phrase_prefix", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("boost", stringLiteral("3.0")))), + unresolvedArg("boost", stringLiteral("2")))); + } + + @Test + void score_query_function_expression() { + assertAnalyzeEqual( + DSL.score_query( + DSL.namedArgument( + "RelevanceQuery", + DSL.wildcard_query( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query")))), + AstDSL.function( + "score_query", + unresolvedArg( + "RelevanceQuery", + AstDSL.function( + "wildcard_query", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")))))); + } + + @Test + void score_query_function_with_boost() { + assertAnalyzeEqual( + DSL.score_query( + DSL.namedArgument( + "RelevanceQuery", + DSL.wildcard_query( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"))), + DSL.namedArgument("boost", "2.0")), + AstDSL.function( + "score_query", + unresolvedArg( + "RelevanceQuery", + AstDSL.function( + "wildcard_query", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")))), + unresolvedArg("boost", stringLiteral("2.0")))); + } + + @Test + void scorequery_function_expression() { + assertAnalyzeEqual( + DSL.scorequery( + DSL.namedArgument( + "RelevanceQuery", + DSL.simple_query_string( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3")))), + AstDSL.function( + "scorequery", + unresolvedArg( + "RelevanceQuery", + AstDSL.function( + "simple_query_string", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")))))); } @Test void scorequery_function_with_boost() { assertAnalyzeEqual( - DSL.scorequery( - DSL.namedArgument("RelevanceQuery", - DSL.simple_query_string( - DSL.namedArgument("field", "field_value1"), - DSL.namedArgument("query", "search query"), - DSL.namedArgument("slop", "3") - )), - DSL.namedArgument("boost", "2.0") - ), - AstDSL.function("scorequery", - unresolvedArg("RelevanceQuery", - AstDSL.function("simple_query_string", - unresolvedArg("field", stringLiteral("field_value1")), - unresolvedArg("query", stringLiteral("search query")), - unresolvedArg("slop", stringLiteral("3")) - ) - ), - unresolvedArg("boost", stringLiteral("2.0")) - ) - ); + DSL.scorequery( + DSL.namedArgument( + "RelevanceQuery", + DSL.simple_query_string( + DSL.namedArgument("field", "field_value1"), + DSL.namedArgument("query", "search query"), + DSL.namedArgument("slop", "3"))), + DSL.namedArgument("boost", "2.0")), + AstDSL.function( + "scorequery", + unresolvedArg( + "RelevanceQuery", + AstDSL.function( + "simple_query_string", + unresolvedArg("field", stringLiteral("field_value1")), + unresolvedArg("query", stringLiteral("search query")), + unresolvedArg("slop", stringLiteral("3")))), + unresolvedArg("boost", stringLiteral("2.0")))); } @Test @@ -764,8 +770,12 @@ public void function_returns_non_constant_value() { // Even a function returns the same values - they are calculated on each call // `sysdate()` which returns `LocalDateTime.now()` shouldn't be cached and should return always // different values - var values = List.of(analyze(function("sysdate")), analyze(function("sysdate")), - analyze(function("sysdate")), analyze(function("sysdate"))); + var values = + List.of( + analyze(function("sysdate")), + analyze(function("sysdate")), + analyze(function("sysdate")), + analyze(function("sysdate"))); var referenceValue = analyze(function("sysdate")).valueOf(); assertTrue(values.stream().noneMatch(v -> v.valueOf() == referenceValue)); } @@ -773,8 +783,12 @@ public void function_returns_non_constant_value() { @Test public void now_as_a_function_not_cached() { // // We can call `now()` as a function, in that case nothing should be cached - var values = List.of(analyze(function("now")), analyze(function("now")), - analyze(function("now")), analyze(function("now"))); + var values = + List.of( + analyze(function("now")), + analyze(function("now")), + analyze(function("now")), + analyze(function("now"))); var referenceValue = analyze(function("now")).valueOf(); assertTrue(values.stream().noneMatch(v -> v.valueOf() == referenceValue)); } @@ -783,13 +797,12 @@ protected Expression analyze(UnresolvedExpression unresolvedExpression) { return expressionAnalyzer.analyze(unresolvedExpression, analysisContext); } - protected void assertAnalyzeEqual(Expression expected, - UnresolvedExpression unresolvedExpression) { + protected void assertAnalyzeEqual( + Expression expected, UnresolvedExpression unresolvedExpression) { assertEquals(expected, analyze(unresolvedExpression)); } - protected void assertAnalyzeEqual(Expression expected, - UnresolvedPlan unresolvedPlan) { + protected void assertAnalyzeEqual(Expression expected, UnresolvedPlan unresolvedPlan) { assertEquals(expected, analyze(unresolvedPlan)); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java b/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java index 89d5f699e3..28bcb8793f 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/ExpressionReferenceOptimizerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static java.util.Collections.emptyList; @@ -27,65 +26,58 @@ class ExpressionReferenceOptimizerTest extends AnalyzerTestBase { void expression_without_aggregation_should_not_be_replaced() { assertEquals( DSL.subtract(DSL.ref("age", INTEGER), DSL.literal(1)), - optimize(DSL.subtract(DSL.ref("age", INTEGER), DSL.literal(1))) - ); + optimize(DSL.subtract(DSL.ref("age", INTEGER), DSL.literal(1)))); } @Test void group_expression_should_be_replaced() { - assertEquals( - DSL.ref("abs(balance)", INTEGER), - optimize(DSL.abs(DSL.ref("balance", INTEGER))) - ); + assertEquals(DSL.ref("abs(balance)", INTEGER), optimize(DSL.abs(DSL.ref("balance", INTEGER)))); } @Test void aggregation_expression_should_be_replaced() { - assertEquals( - DSL.ref("AVG(age)", DOUBLE), - optimize(DSL.avg(DSL.ref("age", INTEGER))) - ); + assertEquals(DSL.ref("AVG(age)", DOUBLE), optimize(DSL.avg(DSL.ref("age", INTEGER)))); } @Test void aggregation_in_expression_should_be_replaced() { assertEquals( DSL.subtract(DSL.ref("AVG(age)", DOUBLE), DSL.literal(1)), - optimize(DSL.subtract(DSL.avg(DSL.ref("age", INTEGER)), DSL.literal(1))) - ); + optimize(DSL.subtract(DSL.avg(DSL.ref("age", INTEGER)), DSL.literal(1)))); } @Test void case_clause_should_be_replaced() { - Expression caseClause = DSL.cases( - null, - DSL.when( - DSL.equal(DSL.ref("age", INTEGER), DSL.literal(30)), - DSL.literal("true"))); + Expression caseClause = + DSL.cases( + null, + DSL.when(DSL.equal(DSL.ref("age", INTEGER), DSL.literal(30)), DSL.literal("true"))); LogicalPlan logicalPlan = LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("test", table), emptyList(), - ImmutableList.of(DSL.named( - "CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")]," - + " defaultResult=null)", - caseClause))); + ImmutableList.of( + DSL.named( + "CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")]," + + " defaultResult=null)", + caseClause))); assertEquals( DSL.ref( "CaseClause(whenClauses=[WhenClause(condition==(age, 30), result=\"true\")]," - + " defaultResult=null)", STRING), + + " defaultResult=null)", + STRING), optimize(caseClause, logicalPlan)); } @Test void aggregation_in_case_when_clause_should_be_replaced() { - Expression caseClause = DSL.cases( - null, - DSL.when( - DSL.equal(DSL.avg(DSL.ref("age", INTEGER)), DSL.literal(30)), - DSL.literal("true"))); + Expression caseClause = + DSL.cases( + null, + DSL.when( + DSL.equal(DSL.avg(DSL.ref("age", INTEGER)), DSL.literal(30)), DSL.literal("true"))); LogicalPlan logicalPlan = LogicalPlanDSL.aggregation( @@ -96,19 +88,16 @@ void aggregation_in_case_when_clause_should_be_replaced() { assertEquals( DSL.cases( null, - DSL.when( - DSL.equal(DSL.ref("AVG(age)", DOUBLE), DSL.literal(30)), - DSL.literal("true"))), + DSL.when(DSL.equal(DSL.ref("AVG(age)", DOUBLE), DSL.literal(30)), DSL.literal("true"))), optimize(caseClause, logicalPlan)); } @Test void aggregation_in_case_else_clause_should_be_replaced() { - Expression caseClause = DSL.cases( - DSL.avg(DSL.ref("age", INTEGER)), - DSL.when( - DSL.equal(DSL.ref("age", INTEGER), DSL.literal(30)), - DSL.literal("true"))); + Expression caseClause = + DSL.cases( + DSL.avg(DSL.ref("age", INTEGER)), + DSL.when(DSL.equal(DSL.ref("age", INTEGER), DSL.literal(30)), DSL.literal("true"))); LogicalPlan logicalPlan = LogicalPlanDSL.aggregation( @@ -119,9 +108,7 @@ void aggregation_in_case_else_clause_should_be_replaced() { assertEquals( DSL.cases( DSL.ref("AVG(age)", DOUBLE), - DSL.when( - DSL.equal(DSL.ref("age", INTEGER), DSL.literal(30)), - DSL.literal("true"))), + DSL.when(DSL.equal(DSL.ref("age", INTEGER), DSL.literal(30)), DSL.literal("true"))), optimize(caseClause, logicalPlan)); } @@ -136,12 +123,8 @@ void window_expression_should_be_replaced() { DSL.named(DSL.denseRank()), new WindowDefinition(emptyList(), emptyList())); - assertEquals( - DSL.ref("rank()", INTEGER), - optimize(DSL.rank(), logicalPlan)); - assertEquals( - DSL.ref("dense_rank()", INTEGER), - optimize(DSL.denseRank(), logicalPlan)); + assertEquals(DSL.ref("rank()", INTEGER), optimize(DSL.rank(), logicalPlan)); + assertEquals(DSL.ref("dense_rank()", INTEGER), optimize(DSL.denseRank(), logicalPlan)); } Expression optimize(Expression expression) { @@ -158,11 +141,11 @@ Expression optimize(Expression expression, LogicalPlan logicalPlan) { LogicalPlan logicalPlan() { return LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList - .of(DSL.named("AVG(age)", DSL.avg(DSL.ref("age", INTEGER))), - DSL.named("SUM(age)", DSL.sum(DSL.ref("age", INTEGER)))), - ImmutableList.of(DSL.named("balance", DSL.ref("balance", INTEGER)), - DSL.named("abs(balance)", DSL.abs(DSL.ref("balance", INTEGER)))) - ); + ImmutableList.of( + DSL.named("AVG(age)", DSL.avg(DSL.ref("age", INTEGER))), + DSL.named("SUM(age)", DSL.sum(DSL.ref("age", INTEGER)))), + ImmutableList.of( + DSL.named("balance", DSL.ref("balance", INTEGER)), + DSL.named("abs(balance)", DSL.abs(DSL.ref("balance", INTEGER))))); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java index e9c891905c..68c508b645 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/NamedExpressionAnalyzerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -22,8 +21,7 @@ class NamedExpressionAnalyzerTest extends AnalyzerTestBase { void visit_named_select_item() { Alias alias = AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value")); - NamedExpressionAnalyzer analyzer = - new NamedExpressionAnalyzer(expressionAnalyzer); + NamedExpressionAnalyzer analyzer = new NamedExpressionAnalyzer(expressionAnalyzer); NamedExpression analyze = analyzer.analyze(alias, analysisContext); assertEquals("integer_value", analyze.getNameOrAlias()); @@ -32,11 +30,10 @@ void visit_named_select_item() { @Test void visit_highlight() { Map args = new HashMap<>(); - Alias alias = AstDSL.alias("highlight(fieldA)", - new HighlightFunction( - AstDSL.stringLiteral("fieldA"), args)); - NamedExpressionAnalyzer analyzer = - new NamedExpressionAnalyzer(expressionAnalyzer); + Alias alias = + AstDSL.alias( + "highlight(fieldA)", new HighlightFunction(AstDSL.stringLiteral("fieldA"), args)); + NamedExpressionAnalyzer analyzer = new NamedExpressionAnalyzer(expressionAnalyzer); NamedExpression analyze = analyzer.analyze(alias, analysisContext); assertEquals("highlight(fieldA)", analyze.getNameOrAlias()); diff --git a/core/src/test/java/org/opensearch/sql/analysis/QualifierAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/QualifierAnalyzerTest.java index 5833ef6ae4..3599a86918 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/QualifierAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/QualifierAnalyzerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -34,18 +33,26 @@ void should_return_original_name_if_no_qualifier() { @Test void should_report_error_if_qualifier_is_not_index() { - runInScope(new Symbol(Namespace.FIELD_NAME, "aIndex"), ARRAY, () -> { - SyntaxCheckException error = assertThrows(SyntaxCheckException.class, - () -> qualifierAnalyzer.unqualified("a", "integer_value")); - assertEquals("The qualifier [a] of qualified name [a.integer_value] " - + "must be an field name, index name or its alias", error.getMessage()); - }); + runInScope( + new Symbol(Namespace.FIELD_NAME, "aIndex"), + ARRAY, + () -> { + SyntaxCheckException error = + assertThrows( + SyntaxCheckException.class, + () -> qualifierAnalyzer.unqualified("a", "integer_value")); + assertEquals( + "The qualifier [a] of qualified name [a.integer_value] " + + "must be an field name, index name or its alias", + error.getMessage()); + }); } @Test void should_report_error_if_qualifier_is_not_exist() { - SyntaxCheckException error = assertThrows(SyntaxCheckException.class, - () -> qualifierAnalyzer.unqualified("a", "integer_value")); + SyntaxCheckException error = + assertThrows( + SyntaxCheckException.class, () -> qualifierAnalyzer.unqualified("a", "integer_value")); assertEquals( "The qualifier [a] of qualified name [a.integer_value] must be an field name, index name " + "or its alias", @@ -54,23 +61,26 @@ void should_report_error_if_qualifier_is_not_exist() { @Test void should_return_qualified_name_if_qualifier_is_index() { - runInScope(new Symbol(Namespace.INDEX_NAME, "a"), STRUCT, () -> - assertEquals("integer_value", qualifierAnalyzer.unqualified("a", "integer_value")) - ); + runInScope( + new Symbol(Namespace.INDEX_NAME, "a"), + STRUCT, + () -> assertEquals("integer_value", qualifierAnalyzer.unqualified("a", "integer_value"))); } @Test void should_return_qualified_name_if_qualifier_is_field() { - runInScope(new Symbol(Namespace.FIELD_NAME, "a"), STRUCT, () -> - assertEquals("a.integer_value", qualifierAnalyzer.unqualified("a", "integer_value")) - ); + runInScope( + new Symbol(Namespace.FIELD_NAME, "a"), + STRUCT, + () -> assertEquals("a.integer_value", qualifierAnalyzer.unqualified("a", "integer_value"))); } @Test void should_report_error_if_more_parts_in_qualified_name() { - runInScope(new Symbol(Namespace.INDEX_NAME, "a"), STRUCT, () -> - qualifierAnalyzer.unqualified("a", "integer_value", "invalid") - ); + runInScope( + new Symbol(Namespace.INDEX_NAME, "a"), + STRUCT, + () -> qualifierAnalyzer.unqualified("a", "integer_value", "invalid")); } private void runInScope(Symbol symbol, ExprType type, Runnable test) { @@ -82,5 +92,4 @@ private void runInScope(Symbol symbol, ExprType type, Runnable test) { analysisContext.pop(); } } - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java index 3bd90f0081..27edc588fa 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/SelectAnalyzeTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.opensearch.sql.ast.dsl.AstDSL.argument; @@ -44,8 +43,7 @@ public void project_all_from_source() { DSL.named("double_value", DSL.ref("double_value", DOUBLE)), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), DSL.named("double_value", DSL.ref("double_value", DOUBLE)), - DSL.named("string_value", DSL.ref("string_value", STRING)) - ), + DSL.named("string_value", DSL.ref("string_value", STRING))), AstDSL.projectWithArg( AstDSL.relation("schema"), AstDSL.defaultFieldsArgs(), @@ -61,11 +59,9 @@ public void select_and_project_all() { LogicalPlanDSL.project( LogicalPlanDSL.relation("schema", table), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), - DSL.named("double_value", DSL.ref("double_value", DOUBLE)) - ), + DSL.named("double_value", DSL.ref("double_value", DOUBLE))), DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), - DSL.named("double_value", DSL.ref("double_value", DOUBLE)) - ), + DSL.named("double_value", DSL.ref("double_value", DOUBLE))), AstDSL.projectWithArg( AstDSL.projectWithArg( AstDSL.relation("schema"), @@ -73,8 +69,7 @@ public void select_and_project_all() { AstDSL.field("integer_value"), AstDSL.field("double_value")), AstDSL.defaultFieldsArgs(), - AllFields.of() - )); + AllFields.of())); } @Test @@ -84,10 +79,8 @@ public void remove_and_project_all() { LogicalPlanDSL.remove( LogicalPlanDSL.relation("schema", table), DSL.ref("integer_value", INTEGER), - DSL.ref("double_value", DOUBLE) - ), - DSL.named("string_value", DSL.ref("string_value", STRING)) - ), + DSL.ref("double_value", DOUBLE)), + DSL.named("string_value", DSL.ref("string_value", STRING))), AstDSL.projectWithArg( AstDSL.projectWithArg( AstDSL.relation("schema"), @@ -95,8 +88,7 @@ public void remove_and_project_all() { AstDSL.field("integer_value"), AstDSL.field("double_value")), AstDSL.defaultFieldsArgs(), - AllFields.of() - )); + AllFields.of())); } @Test @@ -105,20 +97,21 @@ public void stats_and_project_all() { LogicalPlanDSL.project( LogicalPlanDSL.aggregation( LogicalPlanDSL.relation("schema", table), - ImmutableList.of(DSL - .named("avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), + ImmutableList.of( + DSL.named("avg(integer_value)", DSL.avg(DSL.ref("integer_value", INTEGER)))), ImmutableList.of(DSL.named("string_value", DSL.ref("string_value", STRING)))), DSL.named("avg(integer_value)", DSL.ref("avg(integer_value)", DOUBLE)), - DSL.named("string_value", DSL.ref("string_value", STRING)) - ), + DSL.named("string_value", DSL.ref("string_value", STRING))), AstDSL.projectWithArg( AstDSL.agg( AstDSL.relation("schema"), - AstDSL.exprList(AstDSL.alias("avg(integer_value)", AstDSL.aggregate("avg", - field("integer_value")))), + AstDSL.exprList( + AstDSL.alias( + "avg(integer_value)", AstDSL.aggregate("avg", field("integer_value")))), null, ImmutableList.of(AstDSL.alias("string_value", field("string_value"))), - AstDSL.defaultStatsArgs()), AstDSL.defaultFieldsArgs(), + AstDSL.defaultStatsArgs()), + AstDSL.defaultFieldsArgs(), AllFields.of())); } @@ -131,14 +124,12 @@ public void rename_and_project_all() { ImmutableMap.of(DSL.ref("integer_value", INTEGER), DSL.ref("ivalue", INTEGER))), DSL.named("double_value", DSL.ref("double_value", DOUBLE)), DSL.named("string_value", DSL.ref("string_value", STRING)), - DSL.named("ivalue", DSL.ref("ivalue", INTEGER)) - ), + DSL.named("ivalue", DSL.ref("ivalue", INTEGER))), AstDSL.projectWithArg( AstDSL.rename( AstDSL.relation("schema"), AstDSL.map(AstDSL.field("integer_value"), AstDSL.field("ivalue"))), AstDSL.defaultFieldsArgs(), - AllFields.of() - )); + AllFields.of())); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/SelectExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/SelectExpressionAnalyzerTest.java index b2fe29b509..38d4704bcd 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/SelectExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/SelectExpressionAnalyzerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -28,23 +27,20 @@ @ExtendWith(MockitoExtension.class) public class SelectExpressionAnalyzerTest extends AnalyzerTestBase { - @Mock - private ExpressionReferenceOptimizer optimizer; + @Mock private ExpressionReferenceOptimizer optimizer; @Test public void named_expression() { assertAnalyzeEqual( DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), - AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value")) - ); + AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value"))); } @Test public void named_expression_with_alias() { assertAnalyzeEqual( DSL.named("integer_value", DSL.ref("integer_value", INTEGER), "int"), - AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value"), "int") - ); + AstDSL.alias("integer_value", AstDSL.qualifiedName("integer_value"), "int")); } @Test @@ -52,9 +48,8 @@ public void field_name_with_qualifier() { analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); assertAnalyzeEqual( DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), - AstDSL.alias("integer_alias.integer_value", - AstDSL.qualifiedName("index_alias", "integer_value")) - ); + AstDSL.alias( + "integer_alias.integer_value", AstDSL.qualifiedName("index_alias", "integer_value"))); } @Test @@ -62,9 +57,9 @@ public void field_name_with_qualifier_quoted() { analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); assertAnalyzeEqual( DSL.named("integer_value", DSL.ref("integer_value", INTEGER)), - AstDSL.alias("`integer_alias`.integer_value", // qualifier in SELECT is quoted originally - AstDSL.qualifiedName("index_alias", "integer_value")) - ); + AstDSL.alias( + "`integer_alias`.integer_value", // qualifier in SELECT is quoted originally + AstDSL.qualifiedName("index_alias", "integer_value"))); } @Test @@ -72,21 +67,21 @@ public void field_name_in_expression_with_qualifier() { analysisContext.peek().define(new Symbol(Namespace.INDEX_NAME, "index_alias"), STRUCT); assertAnalyzeEqual( DSL.named("abs(index_alias.integer_value)", DSL.abs(DSL.ref("integer_value", INTEGER))), - AstDSL.alias("abs(index_alias.integer_value)", - AstDSL.function("abs", AstDSL.qualifiedName("index_alias", "integer_value"))) - ); + AstDSL.alias( + "abs(index_alias.integer_value)", + AstDSL.function("abs", AstDSL.qualifiedName("index_alias", "integer_value")))); } protected List analyze(UnresolvedExpression unresolvedExpression) { - doAnswer(invocation -> ((NamedExpression) invocation.getArgument(0)) - .getDelegated()).when(optimizer).optimize(any(), any()); + doAnswer(invocation -> ((NamedExpression) invocation.getArgument(0)).getDelegated()) + .when(optimizer) + .optimize(any(), any()); return new SelectExpressionAnalyzer(expressionAnalyzer) - .analyze(Arrays.asList(unresolvedExpression), - analysisContext, optimizer); + .analyze(Arrays.asList(unresolvedExpression), analysisContext, optimizer); } - protected void assertAnalyzeEqual(NamedExpression expected, - UnresolvedExpression unresolvedExpression) { + protected void assertAnalyzeEqual( + NamedExpression expected, UnresolvedExpression unresolvedExpression) { assertEquals(Arrays.asList(expected), analyze(unresolvedExpression)); } } diff --git a/core/src/test/java/org/opensearch/sql/analysis/TypeEnvironmentTest.java b/core/src/test/java/org/opensearch/sql/analysis/TypeEnvironmentTest.java index c963e1d30d..91677a901e 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/TypeEnvironmentTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/TypeEnvironmentTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -21,9 +20,7 @@ public class TypeEnvironmentTest { - /** - * Use context class for push/pop. - */ + /** Use context class for push/pop. */ private AnalysisContext context = new AnalysisContext(); @Test @@ -69,20 +66,24 @@ public void defineFieldSymbolInDifferentEnvironmentsShouldNotAbleToResolveOncePo assertEquals(INTEGER, environment().resolve(toSymbol(age))); SemanticCheckException exception = assertThrows(SemanticCheckException.class, () -> environment().resolve(toSymbol(city))); - assertEquals("can't resolve Symbol(namespace=FIELD_NAME, name=s.city) in type env", + assertEquals( + "can't resolve Symbol(namespace=FIELD_NAME, name=s.city) in type env", exception.getMessage()); - exception = assertThrows(SemanticCheckException.class, - () -> environment().resolve(toSymbol(manager))); - assertEquals("can't resolve Symbol(namespace=FIELD_NAME, name=s.manager) in type env", + exception = + assertThrows(SemanticCheckException.class, () -> environment().resolve(toSymbol(manager))); + assertEquals( + "can't resolve Symbol(namespace=FIELD_NAME, name=s.manager) in type env", exception.getMessage()); } @Test public void resolveLiteralInEnvFailed() { - SemanticCheckException exception = assertThrows(SemanticCheckException.class, - () -> environment().resolve(new Symbol(Namespace.FIELD_NAME, "1"))); - assertEquals("can't resolve Symbol(namespace=FIELD_NAME, name=1) in type env", - exception.getMessage()); + SemanticCheckException exception = + assertThrows( + SemanticCheckException.class, + () -> environment().resolve(new Symbol(Namespace.FIELD_NAME, "1"))); + assertEquals( + "can't resolve Symbol(namespace=FIELD_NAME, name=1) in type env", exception.getMessage()); } private TypeEnvironment environment() { @@ -92,5 +93,4 @@ private TypeEnvironment environment() { private Symbol toSymbol(ReferenceExpression ref) { return new Symbol(Namespace.FIELD_NAME, ref.getAttr()); } - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java b/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java index dd4361ad6a..acb11f0b57 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/WindowExpressionAnalyzerTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -75,16 +74,12 @@ void should_not_generate_sort_operator_if_no_partition_by_and_order_by_list() { LogicalPlanDSL.window( LogicalPlanDSL.relation("test", table), DSL.named("row_number", DSL.rowNumber()), - new WindowDefinition( - ImmutableList.of(), - ImmutableList.of())), + new WindowDefinition(ImmutableList.of(), ImmutableList.of())), analyzer.analyze( AstDSL.alias( "row_number", AstDSL.window( - AstDSL.function("row_number"), - ImmutableList.of(), - ImmutableList.of())), + AstDSL.function("row_number"), ImmutableList.of(), ImmutableList.of())), analysisContext)); } @@ -93,10 +88,7 @@ void should_return_original_child_if_project_item_not_windowed() { assertEquals( child, analyzer.analyze( - AstDSL.alias( - "string_value", - AstDSL.qualifiedName("string_value")), - analysisContext)); + AstDSL.alias("string_value", AstDSL.qualifiedName("string_value")), analysisContext)); } @Test @@ -114,20 +106,23 @@ void can_analyze_sort_options() { .put(new SortOption(DESC, NULL_LAST), DEFAULT_DESC) .build(); - expects.forEach((option, expect) -> { - Alias ast = AstDSL.alias( - "row_number", - AstDSL.window( - AstDSL.function("row_number"), - Collections.emptyList(), - ImmutableList.of( - ImmutablePair.of(option, AstDSL.qualifiedName("integer_value"))))); + expects.forEach( + (option, expect) -> { + Alias ast = + AstDSL.alias( + "row_number", + AstDSL.window( + AstDSL.function("row_number"), + Collections.emptyList(), + ImmutableList.of( + ImmutablePair.of(option, AstDSL.qualifiedName("integer_value"))))); - LogicalPlan plan = analyzer.analyze(ast, analysisContext); - LogicalSort sort = (LogicalSort) plan.getChild().get(0); - assertEquals(expect, sort.getSortList().get(0).getLeft(), - "Assertion failed on input option: " + option); - }); + LogicalPlan plan = analyzer.analyze(ast, analysisContext); + LogicalSort sort = (LogicalSort) plan.getChild().get(0); + assertEquals( + expect, + sort.getSortList().get(0).getLeft(), + "Assertion failed on input option: " + option); + }); } - } diff --git a/core/src/test/java/org/opensearch/sql/analysis/model/DataSourceSchemaIdentifierNameResolverTest.java b/core/src/test/java/org/opensearch/sql/analysis/model/DataSourceSchemaIdentifierNameResolverTest.java index c00bd7705d..775984a528 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/model/DataSourceSchemaIdentifierNameResolverTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/model/DataSourceSchemaIdentifierNameResolverTest.java @@ -7,7 +7,6 @@ package org.opensearch.sql.analysis.model; - import static org.junit.jupiter.api.Assertions.assertEquals; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.when; @@ -27,14 +26,12 @@ @ExtendWith(MockitoExtension.class) public class DataSourceSchemaIdentifierNameResolverTest { - @Mock - private DataSourceService dataSourceService; + @Mock private DataSourceService dataSourceService; @Test void testFullyQualifiedName() { when(dataSourceService.dataSourceExists("prom")).thenReturn(Boolean.TRUE); - identifierOf( - Arrays.asList("prom", "information_schema", "tables"), dataSourceService) + identifierOf(Arrays.asList("prom", "information_schema", "tables"), dataSourceService) .datasource("prom") .schema("information_schema") .name("tables"); @@ -66,8 +63,8 @@ void defaultDataSourceNameResolve() { static class Identifier { private final DataSourceSchemaIdentifierNameResolver resolver; - protected static Identifier identifierOf(List parts, - DataSourceService dataSourceService) { + protected static Identifier identifierOf( + List parts, DataSourceService dataSourceService) { return new Identifier(parts, dataSourceService); } diff --git a/core/src/test/java/org/opensearch/sql/analysis/symbol/SymbolTableTest.java b/core/src/test/java/org/opensearch/sql/analysis/symbol/SymbolTableTest.java index 90f98e8492..176390560e 100644 --- a/core/src/test/java/org/opensearch/sql/analysis/symbol/SymbolTableTest.java +++ b/core/src/test/java/org/opensearch/sql/analysis/symbol/SymbolTableTest.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.analysis.symbol; import static org.hamcrest.MatcherAssert.assertThat; @@ -24,7 +23,6 @@ import org.junit.jupiter.api.Test; import org.opensearch.sql.data.type.ExprType; - public class SymbolTableTest { private SymbolTable symbolTable; @@ -60,13 +58,7 @@ public void defineFieldSymbolShouldBeAbleToResolveByPrefix() { Map typeByName = symbolTable.lookupByPrefix(new Symbol(Namespace.FIELD_NAME, "s.projects")); - assertThat( - typeByName, - allOf( - aMapWithSize(1), - hasEntry("s.projects.active", BOOLEAN) - ) - ); + assertThat(typeByName, allOf(aMapWithSize(1), hasEntry("s.projects.active", BOOLEAN))); } @Test @@ -76,17 +68,11 @@ public void lookupAllFieldsReturnUnnestedFields() { symbolTable.store(new Symbol(Namespace.FIELD_NAME, "active.manager.name"), STRING); symbolTable.store(new Symbol(Namespace.FIELD_NAME, "s.address"), BOOLEAN); - Map typeByName = - symbolTable.lookupAllFields(Namespace.FIELD_NAME); + Map typeByName = symbolTable.lookupAllFields(Namespace.FIELD_NAME); assertThat( typeByName, - allOf( - aMapWithSize(2), - hasEntry("active", BOOLEAN), - hasEntry("s.address", BOOLEAN) - ) - ); + allOf(aMapWithSize(2), hasEntry("active", BOOLEAN), hasEntry("s.address", BOOLEAN))); } @Test @@ -94,8 +80,8 @@ public void failedToResolveSymbolNoNamespaceMatched() { symbolTable.store(new Symbol(Namespace.FUNCTION_NAME, "customFunction"), BOOLEAN); assertFalse(symbolTable.lookup(new Symbol(Namespace.FIELD_NAME, "s.projects")).isPresent()); - assertThat(symbolTable.lookupByPrefix(new Symbol(Namespace.FIELD_NAME, "s.projects")), - anEmptyMap()); + assertThat( + symbolTable.lookupByPrefix(new Symbol(Namespace.FIELD_NAME, "s.projects")), anEmptyMap()); } @Test @@ -111,5 +97,4 @@ private void defineSymbolShouldBeAbleToResolve(Symbol symbol, ExprType expectedT assertTrue(actualType.isPresent()); assertEquals(expectedType, actualType.get()); } - } diff --git a/core/src/test/java/org/opensearch/sql/config/TestConfig.java b/core/src/test/java/org/opensearch/sql/config/TestConfig.java index 6179f020c2..92b6aac64f 100644 --- a/core/src/test/java/org/opensearch/sql/config/TestConfig.java +++ b/core/src/test/java/org/opensearch/sql/config/TestConfig.java @@ -3,7 +3,6 @@ * SPDX-License-Identifier: Apache-2.0 */ - package org.opensearch.sql.config; import com.google.common.collect.ImmutableMap; @@ -23,9 +22,7 @@ import org.opensearch.sql.storage.StorageEngine; import org.opensearch.sql.storage.Table; -/** - * Configuration will be used for UT. - */ +/** Configuration will be used for UT. */ public class TestConfig { public static final String INT_TYPE_NULL_VALUE_FIELD = "int_null_value"; public static final String INT_TYPE_MISSING_VALUE_FIELD = "int_missing_value"; @@ -36,32 +33,33 @@ public class TestConfig { public static final String STRING_TYPE_NULL_VALUE_FIELD = "string_null_value"; public static final String STRING_TYPE_MISSING_VALUE_FIELD = "string_missing_value"; - public static Map typeMapping = new ImmutableMap.Builder() - .put("integer_value", ExprCoreType.INTEGER) - .put(INT_TYPE_NULL_VALUE_FIELD, ExprCoreType.INTEGER) - .put(INT_TYPE_MISSING_VALUE_FIELD, ExprCoreType.INTEGER) - .put("long_value", ExprCoreType.LONG) - .put("float_value", ExprCoreType.FLOAT) - .put("double_value", ExprCoreType.DOUBLE) - .put(DOUBLE_TYPE_NULL_VALUE_FIELD, ExprCoreType.DOUBLE) - .put(DOUBLE_TYPE_MISSING_VALUE_FIELD, ExprCoreType.DOUBLE) - .put("boolean_value", ExprCoreType.BOOLEAN) - .put(BOOL_TYPE_NULL_VALUE_FIELD, ExprCoreType.BOOLEAN) - .put(BOOL_TYPE_MISSING_VALUE_FIELD, ExprCoreType.BOOLEAN) - .put("string_value", ExprCoreType.STRING) - .put(STRING_TYPE_NULL_VALUE_FIELD, ExprCoreType.STRING) - .put(STRING_TYPE_MISSING_VALUE_FIELD, ExprCoreType.STRING) - .put("struct_value", ExprCoreType.STRUCT) - .put("array_value", ExprCoreType.ARRAY) - .put("timestamp_value", ExprCoreType.TIMESTAMP) - .put("field_value1", ExprCoreType.STRING) - .put("field_value2", ExprCoreType.STRING) - .put("message", ExprCoreType.STRING) - .put("message.info", ExprCoreType.STRING) - .put("message.info.id", ExprCoreType.STRING) - .put("comment", ExprCoreType.STRING) - .put("comment.data", ExprCoreType.STRING) - .build(); + public static Map typeMapping = + new ImmutableMap.Builder() + .put("integer_value", ExprCoreType.INTEGER) + .put(INT_TYPE_NULL_VALUE_FIELD, ExprCoreType.INTEGER) + .put(INT_TYPE_MISSING_VALUE_FIELD, ExprCoreType.INTEGER) + .put("long_value", ExprCoreType.LONG) + .put("float_value", ExprCoreType.FLOAT) + .put("double_value", ExprCoreType.DOUBLE) + .put(DOUBLE_TYPE_NULL_VALUE_FIELD, ExprCoreType.DOUBLE) + .put(DOUBLE_TYPE_MISSING_VALUE_FIELD, ExprCoreType.DOUBLE) + .put("boolean_value", ExprCoreType.BOOLEAN) + .put(BOOL_TYPE_NULL_VALUE_FIELD, ExprCoreType.BOOLEAN) + .put(BOOL_TYPE_MISSING_VALUE_FIELD, ExprCoreType.BOOLEAN) + .put("string_value", ExprCoreType.STRING) + .put(STRING_TYPE_NULL_VALUE_FIELD, ExprCoreType.STRING) + .put(STRING_TYPE_MISSING_VALUE_FIELD, ExprCoreType.STRING) + .put("struct_value", ExprCoreType.STRUCT) + .put("array_value", ExprCoreType.ARRAY) + .put("timestamp_value", ExprCoreType.TIMESTAMP) + .put("field_value1", ExprCoreType.STRING) + .put("field_value2", ExprCoreType.STRING) + .put("message", ExprCoreType.STRING) + .put("message.info", ExprCoreType.STRING) + .put("message.info.id", ExprCoreType.STRING) + .put("comment", ExprCoreType.STRING) + .put("comment.data", ExprCoreType.STRING) + .build(); protected StorageEngine storageEngine() { return new StorageEngine() { @@ -94,10 +92,12 @@ public PhysicalPlan implement(LogicalPlan plan) { protected SymbolTable symbolTable() { SymbolTable symbolTable = new SymbolTable(); - typeMapping.entrySet() + typeMapping + .entrySet() .forEach( - entry -> symbolTable - .store(new Symbol(Namespace.FIELD_NAME, entry.getKey()), entry.getValue())); + entry -> + symbolTable.store( + new Symbol(Namespace.FIELD_NAME, entry.getKey()), entry.getValue())); return symbolTable; }