diff --git a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java
index 48098b9741..5010e41942 100644
--- a/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java
+++ b/core/src/main/java/org/opensearch/sql/datasource/model/DataSourceType.java
@@ -8,8 +8,7 @@ public enum DataSourceType {
   PROMETHEUS("prometheus"),
   OPENSEARCH("opensearch"),
-  JDBC("jdbc");
-
+  SPARK("spark");
   private String text;
 
   DataSourceType(String text) {
diff --git a/plugin/build.gradle b/plugin/build.gradle
index 42d5723194..11f97ea857 100644
--- a/plugin/build.gradle
+++ b/plugin/build.gradle
@@ -126,6 +126,7 @@ dependencies {
     api project(':opensearch')
     api project(':prometheus')
     api project(':datasources')
+    api project(':spark')
 
     testImplementation group: 'net.bytebuddy', name: 'byte-buddy-agent', version: '1.12.13'
     testImplementation group: 'org.hamcrest', name: 'hamcrest-library', version: '2.1'
diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
index 36986c9afc..7e867be967 100644
--- a/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
+++ b/plugin/src/main/java/org/opensearch/sql/plugin/SQLPlugin.java
@@ -81,6 +81,7 @@ import org.opensearch.sql.plugin.transport.TransportPPLQueryAction;
 import org.opensearch.sql.plugin.transport.TransportPPLQueryResponse;
 import org.opensearch.sql.prometheus.storage.PrometheusStorageFactory;
+import org.opensearch.sql.spark.storage.SparkStorageFactory;
 import org.opensearch.sql.storage.DataSourceFactory;
 import org.opensearch.threadpool.ExecutorBuilder;
 import org.opensearch.threadpool.FixedExecutorBuilder;
@@ -221,6 +222,7 @@ private DataSourceServiceImpl createDataSourceService() {
         .add(new OpenSearchDataSourceFactory(
             new OpenSearchNodeClient(this.client), pluginSettings))
         .add(new PrometheusStorageFactory(pluginSettings))
+        .add(new SparkStorageFactory(this.client, pluginSettings))
         .build(),
         dataSourceMetadataStorage,
         dataSourceUserAuthorizationHelper);
diff --git a/settings.gradle b/settings.gradle
index 6f7214cb3a..2140ad6c9e 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -19,4 +19,6 @@ include 'legacy'
 include 'sql'
 include 'prometheus'
 include 'benchmarks'
-include 'datasources'
\ No newline at end of file
+include 'datasources'
+include 'spark'
+
diff --git a/spark/build.gradle b/spark/build.gradle
new file mode 100644
index 0000000000..58103dc67f
--- /dev/null
+++ b/spark/build.gradle
@@ -0,0 +1,73 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+plugins {
+    id 'java-library'
+    id "io.freefair.lombok"
+    id 'jacoco'
+}
+
+repositories {
+    mavenCentral()
+}
+
+dependencies {
+    api project(':core')
+    implementation project(':datasources')
+
+    implementation group: 'org.opensearch', name: 'opensearch', version: "${opensearch_version}"
+    implementation group: 'org.json', name: 'json', version: '20230227'
+
+    testImplementation('org.junit.jupiter:junit-jupiter:5.6.2')
+    testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.12.4'
+    testImplementation group: 'org.mockito', name: 'mockito-junit-jupiter', version: '3.12.4'
+}
+
+test {
+    useJUnitPlatform()
+    testLogging {
+        events "passed", "skipped", "failed"
+        exceptionFormat "full"
+    }
+}
+
+jacocoTestReport {
+    reports {
+        html.enabled true
+        xml.enabled true
+    }
+    afterEvaluate {
+        classDirectories.setFrom(files(classDirectories.files.collect {
+            fileTree(dir: it)
+        }))
+    }
+}
+test.finalizedBy(project.tasks.jacocoTestReport)
+
+jacocoTestCoverageVerification {
+    violationRules {
+        rule {
+            element = 'CLASS'
+            excludes = [
+                    'org.opensearch.sql.spark.data.constants.*'
+            ]
+            limit {
+                counter = 'LINE'
+                minimum = 1.0
+            }
+            limit {
+                counter = 'BRANCH'
+                minimum = 1.0
+            }
+        }
+    }
+    afterEvaluate {
+        classDirectories.setFrom(files(classDirectories.files.collect {
+            fileTree(dir: it)
+        }))
+    }
+}
+check.dependsOn jacocoTestCoverageVerification
+jacocoTestCoverageVerification.dependsOn jacocoTestReport
diff --git a/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java
new file mode 100644
index 0000000000..99d8600dd0
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/client/SparkClient.java
@@ -0,0 +1,22 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.client;
+
+import java.io.IOException;
+import org.json.JSONObject;
+
+/**
+ * Interface for a Spark client.
+ */
+public interface SparkClient {
+  /**
+   * Executes the given Spark SQL query.
+   *
+   * @param query spark sql query
+   * @return spark query response
+   */
+  JSONObject sql(String query) throws IOException;
+}
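For orientation, a minimal implementation of this interface might look like the sketch below. The concrete client is deliberately not part of this change, so the LocalSparkClient name and its echo behavior are purely hypothetical:

    package org.opensearch.sql.spark.client;

    import java.io.IOException;
    import org.json.JSONObject;

    /** Hypothetical stub; a real client would submit the query to a Spark endpoint. */
    public class LocalSparkClient implements SparkClient {
      @Override
      public JSONObject sql(String query) throws IOException {
        // Echo the query back as a JSON payload instead of executing it.
        return new JSONObject().put("query", query);
      }
    }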
diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java
new file mode 100644
index 0000000000..1936c266de
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/functions/implementation/SparkSqlFunctionImplementation.java
@@ -0,0 +1,102 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions.implementation;
+
+import static org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver.QUERY;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.data.type.ExprCoreType;
+import org.opensearch.sql.data.type.ExprType;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.FunctionExpression;
+import org.opensearch.sql.expression.NamedArgumentExpression;
+import org.opensearch.sql.expression.env.Environment;
+import org.opensearch.sql.expression.function.FunctionName;
+import org.opensearch.sql.expression.function.TableFunctionImplementation;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.spark.storage.SparkTable;
+import org.opensearch.sql.storage.Table;
+
+/**
+ * Spark SQL function implementation.
+ */
+public class SparkSqlFunctionImplementation extends FunctionExpression
+    implements TableFunctionImplementation {
+
+  private final FunctionName functionName;
+  private final List<Expression> arguments;
+  private final SparkClient sparkClient;
+
+  /**
+   * Constructor for spark sql function.
+   *
+   * @param functionName name of the function
+   * @param arguments a list of expressions
+   * @param sparkClient spark client
+   */
+  public SparkSqlFunctionImplementation(
+      FunctionName functionName, List<Expression> arguments, SparkClient sparkClient) {
+    super(functionName, arguments);
+    this.functionName = functionName;
+    this.arguments = arguments;
+    this.sparkClient = sparkClient;
+  }
+
+  @Override
+  public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
+    throw new UnsupportedOperationException(String.format(
+        "Spark defined function [%s] is only "
+            + "supported in SOURCE clause with spark connector catalog", functionName));
+  }
+
+  @Override
+  public ExprType type() {
+    return ExprCoreType.STRUCT;
+  }
+
+  @Override
+  public String toString() {
+    List<String> args = arguments.stream()
+        .map(arg -> String.format("%s=%s",
+            ((NamedArgumentExpression) arg).getArgName(),
+            ((NamedArgumentExpression) arg).getValue().toString()))
+        .collect(Collectors.toList());
+    return String.format("%s(%s)", functionName, String.join(", ", args));
+  }
+
+  @Override
+  public Table applyArguments() {
+    return new SparkTable(sparkClient, buildQueryFromSqlFunction(arguments));
+  }
+
+  /**
+   * This method builds a spark query request.
+   *
+   * @param arguments spark sql function arguments
+   * @return spark query request
+   */
+  private SparkQueryRequest buildQueryFromSqlFunction(List<Expression> arguments) {
+    SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
+    arguments.forEach(arg -> {
+      String argName = ((NamedArgumentExpression) arg).getArgName();
+      Expression argValue = ((NamedArgumentExpression) arg).getValue();
+      ExprValue literalValue = argValue.valueOf();
+      if (argName.equals(QUERY)) {
+        sparkQueryRequest.setSql((String) literalValue.value());
+      } else {
+        throw new ExpressionEvaluationException(
+            String.format("Invalid Function Argument:%s", argName));
+      }
+    });
+    return sparkQueryRequest;
+  }
+
+}
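As the valueOf exception message indicates, this function is only meaningful in the SOURCE clause of a query against a Spark datasource. Assuming a datasource named myspark has been created (the name is illustrative), a PPL invocation would look like:

    source = myspark.sql(query='select 1')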
diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java
new file mode 100644
index 0000000000..624600e1a8
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/functions/resolver/SparkSqlTableFunctionResolver.java
@@ -0,0 +1,74 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions.resolver;
+
+import static org.opensearch.sql.data.type.ExprCoreType.STRING;
+
+import java.util.ArrayList;
+import java.util.List;
+import lombok.RequiredArgsConstructor;
+import org.apache.commons.lang3.StringUtils;
+import org.apache.commons.lang3.tuple.Pair;
+import org.opensearch.sql.exception.SemanticCheckException;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.NamedArgumentExpression;
+import org.opensearch.sql.expression.function.FunctionBuilder;
+import org.opensearch.sql.expression.function.FunctionName;
+import org.opensearch.sql.expression.function.FunctionResolver;
+import org.opensearch.sql.expression.function.FunctionSignature;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation;
+
+/**
+ * Function resolver for sql function of spark connector.
+ */
+@RequiredArgsConstructor
+public class SparkSqlTableFunctionResolver implements FunctionResolver {
+  private final SparkClient sparkClient;
+
+  public static final String SQL = "sql";
+  public static final String QUERY = "query";
+
+  @Override
+  public Pair<FunctionSignature, FunctionBuilder> resolve(FunctionSignature unresolvedSignature) {
+    FunctionName functionName = FunctionName.of(SQL);
+    FunctionSignature functionSignature =
+        new FunctionSignature(functionName, List.of(STRING));
+    final List<String> argumentNames = List.of(QUERY);
+
+    FunctionBuilder functionBuilder = (functionProperties, arguments) -> {
+      Boolean argumentsPassedByName = arguments.stream()
+          .noneMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName()));
+      Boolean argumentsPassedByPosition = arguments.stream()
+          .allMatch(arg -> StringUtils.isEmpty(((NamedArgumentExpression) arg).getArgName()));
+      if (!(argumentsPassedByName || argumentsPassedByPosition)) {
+        throw new SemanticCheckException("Arguments should be either passed by name or position");
+      }
+
+      if (arguments.size() != argumentNames.size()) {
+        throw new SemanticCheckException(
+            String.format("Missing arguments:[%s]",
+                String.join(",", argumentNames.subList(arguments.size(), argumentNames.size()))));
+      }
+
+      if (argumentsPassedByPosition) {
+        List<Expression> namedArguments = new ArrayList<>();
+        for (int i = 0; i < arguments.size(); i++) {
+          namedArguments.add(new NamedArgumentExpression(argumentNames.get(i),
+              ((NamedArgumentExpression) arguments.get(i)).getValue()));
+        }
+        return new SparkSqlFunctionImplementation(functionName, namedArguments, sparkClient);
+      }
+      return new SparkSqlFunctionImplementation(functionName, arguments, sparkClient);
+    };
+    return Pair.of(functionSignature, functionBuilder);
+  }
+
+  @Override
+  public FunctionName getFunctionName() {
+    return FunctionName.of(SQL);
+  }
+}
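Since the resolver rewrites positional arguments into named ones, both of the following calls (datasource name again illustrative) should produce the same SparkQueryRequest, while mixing named and positional arguments is rejected with "Arguments should be either passed by name or position":

    source = myspark.sql(query='select 1')
    source = myspark.sql('select 1')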
diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java
new file mode 100644
index 0000000000..561f6f2933
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/functions/scan/SparkSqlFunctionTableScanBuilder.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions.scan;
+
+import lombok.AllArgsConstructor;
+import org.opensearch.sql.planner.logical.LogicalProject;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.storage.TableScanOperator;
+import org.opensearch.sql.storage.read.TableScanBuilder;
+
+/**
+ * TableScanBuilder for sql function of spark connector.
+ */
+@AllArgsConstructor
+public class SparkSqlFunctionTableScanBuilder extends TableScanBuilder {
+
+  private final SparkClient sparkClient;
+
+  private final SparkQueryRequest sparkQueryRequest;
+
+  @Override
+  public TableScanOperator build() {
+    //TODO: return SqlFunctionTableScanOperator
+    return null;
+  }
+
+  @Override
+  public boolean pushDownProject(LogicalProject project) {
+    return true;
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java
new file mode 100644
index 0000000000..bc0944a784
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/request/SparkQueryRequest.java
@@ -0,0 +1,21 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.request;
+
+import lombok.Data;
+
+/**
+ * Spark query request.
+ */
+@Data
+public class SparkQueryRequest {
+
+  /**
+   * SQL.
+   */
+  private String sql;
+
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java
new file mode 100644
index 0000000000..a5e35ecc4c
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageEngine.java
@@ -0,0 +1,35 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import java.util.Collection;
+import java.util.Collections;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.DataSourceSchemaName;
+import org.opensearch.sql.expression.function.FunctionResolver;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver;
+import org.opensearch.sql.storage.StorageEngine;
+import org.opensearch.sql.storage.Table;
+
+/**
+ * Spark storage engine implementation.
+ */
+@RequiredArgsConstructor
+public class SparkStorageEngine implements StorageEngine {
+  private final SparkClient sparkClient;
+
+  @Override
+  public Collection<FunctionResolver> getFunctions() {
+    return Collections.singletonList(
+        new SparkSqlTableFunctionResolver(sparkClient));
+  }
+
+  @Override
+  public Table getTable(DataSourceSchemaName dataSourceSchemaName, String tableName) {
+    throw new RuntimeException("Unable to get table from storage engine.");
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java
new file mode 100644
index 0000000000..e4da29f6b4
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkStorageFactory.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import java.util.Map;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.client.Client;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.datasource.model.DataSource;
+import org.opensearch.sql.datasource.model.DataSourceMetadata;
+import org.opensearch.sql.datasource.model.DataSourceType;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.storage.DataSourceFactory;
+import org.opensearch.sql.storage.StorageEngine;
+
+/**
+ * Storage factory implementation for spark connector.
+ */
+@RequiredArgsConstructor
+public class SparkStorageFactory implements DataSourceFactory {
+  private final Client client;
+  private final Settings settings;
+
+  @Override
+  public DataSourceType getDataSourceType() {
+    return DataSourceType.SPARK;
+  }
+
+  @Override
+  public DataSource createDataSource(DataSourceMetadata metadata) {
+    return new DataSource(
+        metadata.getName(),
+        DataSourceType.SPARK,
+        getStorageEngine(metadata.getProperties()));
+  }
+
+  /**
+   * Builds the Spark storage engine.
+   *
+   * @param requiredConfig spark config options
+   * @return spark storage engine object
+   */
+  StorageEngine getStorageEngine(Map<String, String> requiredConfig) {
+    SparkClient sparkClient = null;
+    //TODO: Initialize spark client
+    return new SparkStorageEngine(sparkClient);
+  }
+}
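The factory is driven by DataSourceMetadata carrying the SPARK connector type, exactly as the unit tests below exercise it; a minimal sketch (client and settings as already available in SQLPlugin):

    DataSourceMetadata metadata = new DataSourceMetadata();
    metadata.setName("spark");
    metadata.setConnector(DataSourceType.SPARK);
    metadata.setProperties(new HashMap<>());

    DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata);
    StorageEngine engine = dataSource.getStorageEngine();  // a SparkStorageEngine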
diff --git a/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java
new file mode 100644
index 0000000000..344db8ab7a
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/storage/SparkTable.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import java.util.HashMap;
+import java.util.Map;
+import javax.annotation.Nonnull;
+import lombok.Getter;
+import org.opensearch.sql.data.type.ExprType;
+import org.opensearch.sql.planner.logical.LogicalPlan;
+import org.opensearch.sql.planner.physical.PhysicalPlan;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.storage.Table;
+import org.opensearch.sql.storage.read.TableScanBuilder;
+
+/**
+ * Spark table implementation.
+ * This can be constructed from SparkQueryRequest.
+ */
+public class SparkTable implements Table {
+
+  private final SparkClient sparkClient;
+
+  @Getter
+  private final SparkQueryRequest sparkQueryRequest;
+
+  /**
+   * Constructor for a Spark table backed by a complete SQL request.
+   */
+  public SparkTable(SparkClient sparkService,
+                    @Nonnull SparkQueryRequest sparkQueryRequest) {
+    this.sparkClient = sparkService;
+    this.sparkQueryRequest = sparkQueryRequest;
+  }
+
+  @Override
+  public boolean exists() {
+    throw new UnsupportedOperationException(
+        "Exists operation is not supported in spark datasource");
+  }
+
+  @Override
+  public void create(Map<String, ExprType> schema) {
+    throw new UnsupportedOperationException(
+        "Create operation is not supported in spark datasource");
+  }
+
+  @Override
+  public Map<String, ExprType> getFieldTypes() {
+    return new HashMap<>();
+  }
+
+  @Override
+  public PhysicalPlan implement(LogicalPlan plan) {
+    //TODO: Add plan
+    return null;
+  }
+
+  @Override
+  public TableScanBuilder createScanBuilder() {
+    return new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest);
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java
new file mode 100644
index 0000000000..33c65ff278
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionImplementationTest.java
@@ -0,0 +1,77 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+import java.util.List;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.data.type.ExprCoreType;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.DSL;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.function.FunctionName;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.spark.storage.SparkTable;
+
+@ExtendWith(MockitoExtension.class)
+public class SparkSqlFunctionImplementationTest {
+  @Mock
+  private SparkClient client;
+
+  @Test
+  void testValueOfAndTypeToString() {
+    FunctionName functionName = new FunctionName("sql");
+    List<Expression> namedArgumentExpressionList
+        = List.of(DSL.namedArgument("query", DSL.literal("select 1")));
+    SparkSqlFunctionImplementation sparkSqlFunctionImplementation
+        = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client);
+    UnsupportedOperationException exception = assertThrows(UnsupportedOperationException.class,
+        () -> sparkSqlFunctionImplementation.valueOf());
+    assertEquals("Spark defined function [sql] is only "
+        + "supported in SOURCE clause with spark connector catalog", exception.getMessage());
+    assertEquals("sql(query=\"select 1\")",
+        sparkSqlFunctionImplementation.toString());
+    assertEquals(ExprCoreType.STRUCT, sparkSqlFunctionImplementation.type());
+  }
+
+
+  @Test
+  void testApplyArguments() {
+    FunctionName functionName = new FunctionName("sql");
+    List<Expression> namedArgumentExpressionList
+        = List.of(DSL.namedArgument("query", DSL.literal("select 1")));
+    SparkSqlFunctionImplementation sparkSqlFunctionImplementation
+        = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client);
+    SparkTable sparkTable
+        = (SparkTable) sparkSqlFunctionImplementation.applyArguments();
+    assertNotNull(sparkTable.getSparkQueryRequest());
+    SparkQueryRequest sparkQueryRequest
+        = sparkTable.getSparkQueryRequest();
+    assertEquals("select 1", sparkQueryRequest.getSql());
+  }
+
+  @Test
+  void testApplyArgumentsException() {
+    FunctionName functionName = new FunctionName("sql");
+    List<Expression> namedArgumentExpressionList
+        = List.of(DSL.namedArgument("query", DSL.literal("select 1")),
+            DSL.namedArgument("tmp", DSL.literal(12345)));
+    SparkSqlFunctionImplementation sparkSqlFunctionImplementation
+        = new SparkSqlFunctionImplementation(functionName, namedArgumentExpressionList, client);
+    ExpressionEvaluationException exception = assertThrows(ExpressionEvaluationException.class,
+        () -> sparkSqlFunctionImplementation.applyArguments());
+    assertEquals("Invalid Function Argument:tmp", exception.getMessage());
+  }
+
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java
new file mode 100644
index 0000000000..f5fb0983cc
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanBuilderTest.java
@@ -0,0 +1,45 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.mockito.Mock;
+import org.opensearch.sql.planner.logical.LogicalProject;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.storage.TableScanOperator;
+
+public class SparkSqlFunctionTableScanBuilderTest {
+  @Mock
+  private SparkClient sparkClient;
+
+  @Mock
+  private LogicalProject logicalProject;
+
+  @Test
+  void testBuild() {
+    SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
+    sparkQueryRequest.setSql("select 1");
+
+    SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder
+        = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest);
+    TableScanOperator sqlFunctionTableScanOperator
+        = sparkSqlFunctionTableScanBuilder.build();
+    Assertions.assertNull(sqlFunctionTableScanOperator);
+  }
+
+  @Test
+  void testPushProject() {
+    SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
+    sparkQueryRequest.setSql("select 1");
+
+    SparkSqlFunctionTableScanBuilder sparkSqlFunctionTableScanBuilder
+        = new SparkSqlFunctionTableScanBuilder(sparkClient, sparkQueryRequest);
+    Assertions.assertTrue(sparkSqlFunctionTableScanBuilder.pushDownProject(logicalProject));
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java
new file mode 100644
index 0000000000..491e9bbd73
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlTableFunctionResolverTest.java
@@ -0,0 +1,140 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.functions;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.opensearch.sql.data.type.ExprCoreType.STRING;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.exception.SemanticCheckException;
+import org.opensearch.sql.expression.DSL;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.function.FunctionBuilder;
+import org.opensearch.sql.expression.function.FunctionName;
+import org.opensearch.sql.expression.function.FunctionProperties;
+import org.opensearch.sql.expression.function.FunctionSignature;
+import org.opensearch.sql.expression.function.TableFunctionImplementation;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.implementation.SparkSqlFunctionImplementation;
+import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.spark.storage.SparkTable;
+
+@ExtendWith(MockitoExtension.class)
+public class SparkSqlTableFunctionResolverTest {
+  @Mock
+  private SparkClient client;
+
+  @Mock
+  private FunctionProperties functionProperties;
+
+  @Test
+  void testResolve() {
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver
+        = new SparkSqlTableFunctionResolver(client);
+    FunctionName functionName = FunctionName.of("sql");
+    List<Expression> expressions
+        = List.of(DSL.namedArgument("query", DSL.literal("select 1")));
+    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
+        .stream().map(Expression::type).collect(Collectors.toList()));
+    Pair<FunctionSignature, FunctionBuilder> resolution
+        = sqlTableFunctionResolver.resolve(functionSignature);
+    assertEquals(functionName, resolution.getKey().getFunctionName());
+    assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
+    assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
+    FunctionBuilder functionBuilder = resolution.getValue();
+    TableFunctionImplementation functionImplementation
+        = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions);
+    assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation);
+    SparkTable sparkTable
+        = (SparkTable) functionImplementation.applyArguments();
+    assertNotNull(sparkTable.getSparkQueryRequest());
+    SparkQueryRequest sparkQueryRequest =
+        sparkTable.getSparkQueryRequest();
+    assertEquals("select 1", sparkQueryRequest.getSql());
+  }
+
+  @Test
+  void testArgumentsPassedByPosition() {
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver
+        = new SparkSqlTableFunctionResolver(client);
+    FunctionName functionName = FunctionName.of("sql");
+    List<Expression> expressions
+        = List.of(DSL.namedArgument(null, DSL.literal("select 1")));
+    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
+        .stream().map(Expression::type).collect(Collectors.toList()));
+
+    Pair<FunctionSignature, FunctionBuilder> resolution
+        = sqlTableFunctionResolver.resolve(functionSignature);
+
+    assertEquals(functionName, resolution.getKey().getFunctionName());
+    assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
+    assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
+    FunctionBuilder functionBuilder = resolution.getValue();
+    TableFunctionImplementation functionImplementation
+        = (TableFunctionImplementation) functionBuilder.apply(functionProperties, expressions);
+    assertTrue(functionImplementation instanceof SparkSqlFunctionImplementation);
+    SparkTable sparkTable
+        = (SparkTable) functionImplementation.applyArguments();
+    assertNotNull(sparkTable.getSparkQueryRequest());
+    SparkQueryRequest sparkQueryRequest =
+        sparkTable.getSparkQueryRequest();
+    assertEquals("select 1", sparkQueryRequest.getSql());
+  }
+
+  @Test
+  void testMixedArgumentTypes() {
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver
+        = new SparkSqlTableFunctionResolver(client);
+    FunctionName functionName = FunctionName.of("sql");
+    List<Expression> expressions
+        = List.of(DSL.namedArgument("query", DSL.literal("select 1")),
+            DSL.namedArgument(null, DSL.literal(12345)));
+    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
+        .stream().map(Expression::type).collect(Collectors.toList()));
+    Pair<FunctionSignature, FunctionBuilder> resolution
+        = sqlTableFunctionResolver.resolve(functionSignature);
+
+    assertEquals(functionName, resolution.getKey().getFunctionName());
+    assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
+    assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
+    SemanticCheckException exception = assertThrows(SemanticCheckException.class,
+        () -> resolution.getValue().apply(functionProperties, expressions));
+
+    assertEquals("Arguments should be either passed by name or position", exception.getMessage());
+  }
+
+  @Test
+  void testWrongArgumentsSizeWhenPassedByName() {
+    SparkSqlTableFunctionResolver sqlTableFunctionResolver
+        = new SparkSqlTableFunctionResolver(client);
+    FunctionName functionName = FunctionName.of("sql");
+    List<Expression> expressions
+        = List.of();
+    FunctionSignature functionSignature = new FunctionSignature(functionName, expressions
+        .stream().map(Expression::type).collect(Collectors.toList()));
+    Pair<FunctionSignature, FunctionBuilder> resolution
+        = sqlTableFunctionResolver.resolve(functionSignature);
+
+    assertEquals(functionName, resolution.getKey().getFunctionName());
+    assertEquals(functionName, sqlTableFunctionResolver.getFunctionName());
+    assertEquals(List.of(STRING), resolution.getKey().getParamTypeList());
+    SemanticCheckException exception = assertThrows(SemanticCheckException.class,
+        () -> resolution.getValue().apply(functionProperties, expressions));
+
+    assertEquals("Missing arguments:[query]", exception.getMessage());
+  }
+
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
new file mode 100644
index 0000000000..7adcc725fa
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageEngineTest.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import java.util.Collection;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.DataSourceSchemaName;
+import org.opensearch.sql.expression.function.FunctionResolver;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver;
+import org.opensearch.sql.storage.Table;
+
+@ExtendWith(MockitoExtension.class)
+public class SparkStorageEngineTest {
+  @Mock
+  private SparkClient client;
+
+  @Test
+  public void getFunctions() {
+    SparkStorageEngine engine = new SparkStorageEngine(client);
+    Collection<FunctionResolver> functionResolverCollection
+        = engine.getFunctions();
+    assertNotNull(functionResolverCollection);
+    assertEquals(1, functionResolverCollection.size());
+    assertTrue(
+        functionResolverCollection.iterator().next() instanceof SparkSqlTableFunctionResolver);
+  }
+
+  @Test
+  public void getTable() {
+    SparkStorageEngine engine = new SparkStorageEngine(client);
+    RuntimeException exception = assertThrows(RuntimeException.class,
+        () -> engine.getTable(new DataSourceSchemaName("spark", "default"), ""));
+    assertEquals("Unable to get table from storage engine.", exception.getMessage());
+  }
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
new file mode 100644
index 0000000000..4142cfe355
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkStorageFactoryTest.java
@@ -0,0 +1,57 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import java.util.HashMap;
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.client.Client;
+import org.opensearch.sql.common.setting.Settings;
+import org.opensearch.sql.datasource.model.DataSource;
+import org.opensearch.sql.datasource.model.DataSourceMetadata;
+import org.opensearch.sql.datasource.model.DataSourceType;
+import org.opensearch.sql.storage.StorageEngine;
+
+@ExtendWith(MockitoExtension.class)
+public class SparkStorageFactoryTest {
+  @Mock
+  private Settings settings;
+
+  @Mock
+  private Client client;
+
+  @Test
+  void testGetConnectorType() {
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    Assertions.assertEquals(
+        DataSourceType.SPARK, sparkStorageFactory.getDataSourceType());
+  }
+
+  @Test
+  @SneakyThrows
+  void testGetStorageEngine() {
+    SparkStorageFactory sparkStorageFactory = new SparkStorageFactory(client, settings);
+    StorageEngine storageEngine
+        = sparkStorageFactory.getStorageEngine(new HashMap<>());
+    Assertions.assertTrue(storageEngine instanceof SparkStorageEngine);
+  }
+
+  @Test
+  void createDataSourceSuccessWithLocalhost() {
+    DataSourceMetadata metadata = new DataSourceMetadata();
+    metadata.setName("spark");
+    metadata.setConnector(DataSourceType.SPARK);
+    metadata.setProperties(new HashMap<>());
+
+    DataSource dataSource = new SparkStorageFactory(client, settings).createDataSource(metadata);
+    Assertions.assertTrue(dataSource.getStorageEngine() instanceof SparkStorageEngine);
+  }
+
+}
diff --git a/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
new file mode 100644
index 0000000000..d3487d65c1
--- /dev/null
+++ b/spark/src/test/java/org/opensearch/sql/spark/storage/SparkTableTest.java
@@ -0,0 +1,87 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.storage;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.Mockito.verifyNoMoreInteractions;
+import static org.opensearch.sql.planner.logical.LogicalPlanDSL.project;
+import static org.opensearch.sql.planner.logical.LogicalPlanDSL.relation;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import lombok.SneakyThrows;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.extension.ExtendWith;
+import org.mockito.Mock;
+import org.mockito.junit.jupiter.MockitoExtension;
+import org.opensearch.sql.data.type.ExprType;
+import org.opensearch.sql.expression.NamedExpression;
+import org.opensearch.sql.planner.physical.PhysicalPlan;
+import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanBuilder;
+import org.opensearch.sql.spark.request.SparkQueryRequest;
+import org.opensearch.sql.storage.read.TableScanBuilder;
+
+@ExtendWith(MockitoExtension.class)
+public class SparkTableTest {
+  @Mock
+  private SparkClient client;
+
+  @Test
+  void testUnsupportedOperation() {
+    SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
+    SparkTable sparkTable =
+        new SparkTable(client, sparkQueryRequest);
+
+    assertThrows(UnsupportedOperationException.class, sparkTable::exists);
+    assertThrows(UnsupportedOperationException.class,
+        () -> sparkTable.create(Collections.emptyMap()));
+  }
+
+  @Test
+  void testCreateScanBuilderWithSqlTableFunction() {
+    SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
+    sparkQueryRequest.setSql("select 1");
+    SparkTable sparkTable =
+        new SparkTable(client, sparkQueryRequest);
+    TableScanBuilder tableScanBuilder = sparkTable.createScanBuilder();
+    Assertions.assertNotNull(tableScanBuilder);
+    Assertions.assertTrue(tableScanBuilder instanceof SparkSqlFunctionTableScanBuilder);
+  }
+
+  @Test
+  @SneakyThrows
+  void testGetFieldTypesFromSparkQueryRequest() {
+    SparkTable sparkTable
+        = new SparkTable(client, new SparkQueryRequest());
+    Map<String, ExprType> expectedFieldTypes = new HashMap<>();
+    Map<String, ExprType> fieldTypes = sparkTable.getFieldTypes();
+
+    assertEquals(expectedFieldTypes, fieldTypes);
+    verifyNoMoreInteractions(client);
+    assertNotNull(sparkTable.getSparkQueryRequest());
+  }
+
+  @Test
+  void testImplementWithSqlFunction() {
+    SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
+    sparkQueryRequest.setSql("select 1");
+    SparkTable sparkTable =
+        new SparkTable(client, sparkQueryRequest);
+    List<NamedExpression> finalProjectList = new ArrayList<>();
+    PhysicalPlan plan = sparkTable.implement(
+        project(relation("sql", sparkTable),
+            finalProjectList, null));
+    assertNull(plan);
+  }
+}
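Taken together, resolution flows from the sql table function to a SparkTable carrying the query request. A condensed end-to-end sketch mirroring what the tests above verify (LocalSparkClient is the hypothetical stub from earlier; the no-arg FunctionProperties constructor is an assumption and may differ by version):

    package org.opensearch.sql.spark;

    import java.util.List;
    import java.util.stream.Collectors;
    import org.apache.commons.lang3.tuple.Pair;
    import org.opensearch.sql.expression.DSL;
    import org.opensearch.sql.expression.Expression;
    import org.opensearch.sql.expression.function.FunctionBuilder;
    import org.opensearch.sql.expression.function.FunctionName;
    import org.opensearch.sql.expression.function.FunctionProperties;
    import org.opensearch.sql.expression.function.FunctionSignature;
    import org.opensearch.sql.expression.function.TableFunctionImplementation;
    import org.opensearch.sql.spark.client.LocalSparkClient;
    import org.opensearch.sql.spark.client.SparkClient;
    import org.opensearch.sql.spark.functions.resolver.SparkSqlTableFunctionResolver;
    import org.opensearch.sql.spark.storage.SparkTable;

    public class SparkSqlFlowExample {
      public static void main(String[] args) {
        SparkClient client = new LocalSparkClient(); // hypothetical stub from earlier
        SparkSqlTableFunctionResolver resolver = new SparkSqlTableFunctionResolver(client);

        // Bind the single named argument the resolver expects.
        List<Expression> arguments = List.of(DSL.namedArgument("query", DSL.literal("select 1")));
        FunctionSignature signature = new FunctionSignature(FunctionName.of("sql"),
            arguments.stream().map(Expression::type).collect(Collectors.toList()));

        // Resolution yields a builder; applying it produces the table function.
        Pair<FunctionSignature, FunctionBuilder> resolved = resolver.resolve(signature);
        TableFunctionImplementation impl = (TableFunctionImplementation)
            resolved.getValue().apply(new FunctionProperties(), arguments);

        // The resulting SparkTable carries the SparkQueryRequest built from the argument.
        SparkTable table = (SparkTable) impl.applyArguments();
        System.out.println(table.getSparkQueryRequest().getSql()); // select 1
      }
    }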