From e59bf75d701baa70df88fd6b89f5d9f194004f63 Mon Sep 17 00:00:00 2001
From: Peng Huo
Date: Wed, 31 Jan 2024 17:44:17 -0800
Subject: [PATCH] Add SparkDataType as wrapper for unmapped spark data type
 (#2492)

* Add SparkDataType as wrapper for unmapped spark data type

Signed-off-by: Peng Huo

* add IT for parsing explain query response

Signed-off-by: Peng Huo

---------

Signed-off-by: Peng Huo
---
 .../sql/spark/data/type/SparkDataType.java    |  24 ++
 .../sql/spark/data/value/SparkExprValue.java  |  40 +++
 ...DefaultSparkSqlFunctionResponseHandle.java |  74 +++---
 .../AsyncQueryGetResultSpecTest.java          | 246 +++++++++++++++++-
 .../spark/data/value/SparkExprValueTest.java  |  28 ++
 ...SparkSqlFunctionTableScanOperatorTest.java |  79 +++++-
 .../src/test/resources/invalid_data_type.json |  12 -
 spark/src/test/resources/spark_data_type.json |  13 +
 8 files changed, 460 insertions(+), 56 deletions(-)
 create mode 100644 spark/src/main/java/org/opensearch/sql/spark/data/type/SparkDataType.java
 create mode 100644 spark/src/main/java/org/opensearch/sql/spark/data/value/SparkExprValue.java
 create mode 100644 spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java
 delete mode 100644 spark/src/test/resources/invalid_data_type.json
 create mode 100644 spark/src/test/resources/spark_data_type.json

diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/type/SparkDataType.java b/spark/src/main/java/org/opensearch/sql/spark/data/type/SparkDataType.java
new file mode 100644
index 0000000000..5d36492d72
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/data/type/SparkDataType.java
@@ -0,0 +1,24 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.data.type;
+
+import lombok.EqualsAndHashCode;
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.data.type.ExprType;
+
+/** Wrapper of an unmapped Spark data type. */
+@EqualsAndHashCode
+@RequiredArgsConstructor
+public class SparkDataType implements ExprType {
+
+  /** Spark data type name. */
+  private final String typeName;
+
+  @Override
+  public String typeName() {
+    return typeName;
+  }
+}
diff --git a/spark/src/main/java/org/opensearch/sql/spark/data/value/SparkExprValue.java b/spark/src/main/java/org/opensearch/sql/spark/data/value/SparkExprValue.java
new file mode 100644
index 0000000000..1d5f6296a7
--- /dev/null
+++ b/spark/src/main/java/org/opensearch/sql/spark/data/value/SparkExprValue.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.spark.data.value;
+
+import lombok.RequiredArgsConstructor;
+import org.opensearch.sql.data.model.AbstractExprValue;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.data.type.ExprType;
+import org.opensearch.sql.spark.data.type.SparkDataType;
+
+/** SparkExprValue holds a Spark query response value.
*/ +@RequiredArgsConstructor +public class SparkExprValue extends AbstractExprValue { + + private final SparkDataType type; + private final Object value; + + @Override + public Object value() { + return value; + } + + @Override + public ExprType type() { + return type; + } + + @Override + public int compare(ExprValue other) { + throw new UnsupportedOperationException("SparkExprValue is not comparable"); + } + + @Override + public boolean equal(ExprValue other) { + return value.equals(other.value()); + } +} diff --git a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java index 422d1caaf1..8a571d1dda 100644 --- a/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java +++ b/spark/src/main/java/org/opensearch/sql/spark/functions/response/DefaultSparkSqlFunctionResponseHandle.java @@ -19,6 +19,7 @@ import org.opensearch.sql.data.model.ExprFloatValue; import org.opensearch.sql.data.model.ExprIntegerValue; import org.opensearch.sql.data.model.ExprLongValue; +import org.opensearch.sql.data.model.ExprNullValue; import org.opensearch.sql.data.model.ExprShortValue; import org.opensearch.sql.data.model.ExprStringValue; import org.opensearch.sql.data.model.ExprTimestampValue; @@ -27,6 +28,8 @@ import org.opensearch.sql.data.type.ExprCoreType; import org.opensearch.sql.data.type.ExprType; import org.opensearch.sql.executor.ExecutionEngine; +import org.opensearch.sql.spark.data.type.SparkDataType; +import org.opensearch.sql.spark.data.value.SparkExprValue; /** Default implementation of SparkSqlFunctionResponseHandle. */ public class DefaultSparkSqlFunctionResponseHandle implements SparkSqlFunctionResponseHandle { @@ -64,30 +67,43 @@ private static LinkedHashMap extractRow( LinkedHashMap linkedHashMap = new LinkedHashMap<>(); for (ExecutionEngine.Schema.Column column : columnList) { ExprType type = column.getExprType(); - if (type == ExprCoreType.BOOLEAN) { - linkedHashMap.put(column.getName(), ExprBooleanValue.of(row.getBoolean(column.getName()))); - } else if (type == ExprCoreType.LONG) { - linkedHashMap.put(column.getName(), new ExprLongValue(row.getLong(column.getName()))); - } else if (type == ExprCoreType.INTEGER) { - linkedHashMap.put(column.getName(), new ExprIntegerValue(row.getInt(column.getName()))); - } else if (type == ExprCoreType.SHORT) { - linkedHashMap.put(column.getName(), new ExprShortValue(row.getInt(column.getName()))); - } else if (type == ExprCoreType.BYTE) { - linkedHashMap.put(column.getName(), new ExprByteValue(row.getInt(column.getName()))); - } else if (type == ExprCoreType.DOUBLE) { - linkedHashMap.put(column.getName(), new ExprDoubleValue(row.getDouble(column.getName()))); - } else if (type == ExprCoreType.FLOAT) { - linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); - } else if (type == ExprCoreType.DATE) { - // TODO :: correct this to ExprTimestampValue - linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); - } else if (type == ExprCoreType.TIMESTAMP) { - linkedHashMap.put( - column.getName(), new ExprTimestampValue(row.getString(column.getName()))); - } else if (type == ExprCoreType.STRING) { - linkedHashMap.put(column.getName(), new ExprStringValue(jsonString(row, column.getName()))); + if (!row.has(column.getName())) { + linkedHashMap.put(column.getName(), ExprNullValue.of()); } else 
{ - throw new RuntimeException("Result contains invalid data type"); + if (type == ExprCoreType.BOOLEAN) { + linkedHashMap.put( + column.getName(), ExprBooleanValue.of(row.getBoolean(column.getName()))); + } else if (type == ExprCoreType.LONG) { + linkedHashMap.put(column.getName(), new ExprLongValue(row.getLong(column.getName()))); + } else if (type == ExprCoreType.INTEGER) { + linkedHashMap.put(column.getName(), new ExprIntegerValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.SHORT) { + linkedHashMap.put(column.getName(), new ExprShortValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.BYTE) { + linkedHashMap.put(column.getName(), new ExprByteValue(row.getInt(column.getName()))); + } else if (type == ExprCoreType.DOUBLE) { + linkedHashMap.put(column.getName(), new ExprDoubleValue(row.getDouble(column.getName()))); + } else if (type == ExprCoreType.FLOAT) { + linkedHashMap.put(column.getName(), new ExprFloatValue(row.getFloat(column.getName()))); + } else if (type == ExprCoreType.DATE) { + // TODO :: correct this to ExprTimestampValue + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); + } else if (type == ExprCoreType.TIMESTAMP) { + linkedHashMap.put( + column.getName(), new ExprTimestampValue(row.getString(column.getName()))); + } else if (type == ExprCoreType.STRING) { + linkedHashMap.put(column.getName(), new ExprStringValue(row.getString(column.getName()))); + } else { + // SparkDataType + Object jsonValue = row.get(column.getName()); + Object value = jsonValue; + if (jsonValue instanceof JSONObject) { + value = ((JSONObject) jsonValue).toMap(); + } else if (jsonValue instanceof JSONArray) { + value = ((JSONArray) jsonValue).toList(); + } + linkedHashMap.put(column.getName(), new SparkExprValue((SparkDataType) type, value)); + } } } @@ -107,8 +123,8 @@ private List getColumnList(JSONArray schema) { return columnList; } - private ExprCoreType getDataType(String sparkDataType) { - switch (sparkDataType) { + private ExprType getDataType(String sparkType) { + switch (sparkType) { case "boolean": return ExprCoreType.BOOLEAN; case "long": @@ -128,18 +144,12 @@ private ExprCoreType getDataType(String sparkDataType) { case "date": return ExprCoreType.TIMESTAMP; case "string": - case "varchar": - case "char": return ExprCoreType.STRING; default: - return ExprCoreType.UNKNOWN; + return new SparkDataType(sparkType); } } - private static String jsonString(JSONObject jsonObject, String key) { - return jsonObject.has(key) ? 
jsonObject.getString(key) : ""; - } - @Override public boolean hasNext() { return responseIterator.hasNext(); diff --git a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java index bba38693cd..2ddfe77868 100644 --- a/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java +++ b/spark/src/test/java/org/opensearch/sql/spark/asyncquery/AsyncQueryGetResultSpecTest.java @@ -21,7 +21,11 @@ import org.junit.Test; import org.opensearch.action.index.IndexRequest; import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.executor.pagination.Cursor; +import org.opensearch.sql.protocol.response.format.JsonResponseFormatter; +import org.opensearch.sql.protocol.response.format.ResponseFormatter; import org.opensearch.sql.spark.asyncquery.model.AsyncQueryExecutionResponse; +import org.opensearch.sql.spark.asyncquery.model.AsyncQueryResult; import org.opensearch.sql.spark.execution.statement.StatementModel; import org.opensearch.sql.spark.execution.statement.StatementState; import org.opensearch.sql.spark.execution.statestore.StateStore; @@ -30,6 +34,7 @@ import org.opensearch.sql.spark.rest.model.CreateAsyncQueryRequest; import org.opensearch.sql.spark.rest.model.CreateAsyncQueryResponse; import org.opensearch.sql.spark.rest.model.LangType; +import org.opensearch.sql.spark.transport.format.AsyncQueryResultResponseFormatter; public class AsyncQueryGetResultSpecTest extends AsyncQueryExecutorServiceSpec { @@ -181,6 +186,217 @@ public void testDropIndexQueryGetResultWithResultDocRefreshDelay() { .assertQueryResults("SUCCESS", ImmutableList.of()); } + @Test + public void testInteractiveQueryResponse() { + createAsyncQuery("SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc(createResultDoc(interaction.queryId)); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"1\"," + + "\"type\":\"integer\"}],\"datarows\":[[1]],\"total\":1,\"size\":1}"); + } + + @Test + public void testInteractiveQueryResponseBasicType() { + createAsyncQuery("SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of( + "{'column1': 'value1', 'column2': 123, 'column3': true}", + "{'column1': 'value2', 'column2': 456, 'column3': false}"), + ImmutableList.of( + "{'column_name': 'column1', 'data_type': 'string'}", + "{'column_name': 'column2', 'data_type': 'integer'}", + "{'column_name': 'column3', 'data_type': 'boolean'}"))); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"column1\",\"type\":\"string\"},{\"name\":\"column2\",\"type\":\"integer\"},{\"name\":\"column3\",\"type\":\"boolean\"}],\"datarows\":[[\"value1\",123,true],[\"value2\",456,false]],\"total\":2,\"size\":2}"); + } + + @Test + public void testInteractiveQueryResponseJsonArray() { + createAsyncQuery("SELECT * FROM TABLE") + 
.withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of( + "{ 'attributes': [{ 'key': 'telemetry.sdk.language', 'value': {" + + " 'stringValue': 'python' }}, { 'key': 'telemetry.sdk.name'," + + " 'value': { 'stringValue': 'opentelemetry' }}, { 'key':" + + " 'telemetry.sdk.version', 'value': { 'stringValue': '1.19.0' }}, {" + + " 'key': 'service.namespace', 'value': { 'stringValue':" + + " 'opentelemetry-demo' }}, { 'key': 'service.name', 'value': {" + + " 'stringValue': 'recommendationservice' }}, { 'key':" + + " 'telemetry.auto.version', 'value': { 'stringValue': '0.40b0'" + + " }}]}"), + ImmutableList.of("{'column_name':'attributes','data_type':'array'}"))); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"attributes\",\"type\":\"array\"}],\"datarows\":[[[{\"value\":{\"stringValue\":\"python\"},\"key\":\"telemetry.sdk.language\"},{\"value\":{\"stringValue\":\"opentelemetry\"},\"key\":\"telemetry.sdk.name\"},{\"value\":{\"stringValue\":\"1.19.0\"},\"key\":\"telemetry.sdk.version\"},{\"value\":{\"stringValue\":\"opentelemetry-demo\"},\"key\":\"service.namespace\"},{\"value\":{\"stringValue\":\"recommendationservice\"},\"key\":\"service.name\"},{\"value\":{\"stringValue\":\"0.40b0\"},\"key\":\"telemetry.auto.version\"}]]],\"total\":1,\"size\":1}"); + } + + @Test + public void testInteractiveQueryResponseJsonNested() { + createAsyncQuery("SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of( + "{\n" + + " 'resourceSpans': {\n" + + " 'scopeSpans': {\n" + + " 'spans': {\n" + + " 'key': 'rpc.system',\n" + + " 'value': {\n" + + " 'stringValue': 'grpc'\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"), + ImmutableList.of("{'column_name':'resourceSpans','data_type':'struct'}"))); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"resourceSpans\",\"type\":\"struct\"}],\"datarows\":[[{\"scopeSpans\":{\"spans\":{\"value\":{\"stringValue\":\"grpc\"},\"key\":\"rpc.system\"}}}]],\"total\":1,\"size\":1}"); + } + + @Test + public void testInteractiveQueryResponseJsonNestedObjectArray() { + createAsyncQuery("SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of( + "{\n" + + " 'resourceSpans': \n" + + " {\n" + + " 'scopeSpans': \n" + + " {\n" + + " 'spans': \n" + + " [\n" + + " {\n" + + " 'attribute': {\n" + + " 'key': 'rpc.system',\n" + + " 'value': {\n" + + " 'stringValue': 'grpc'\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " 'attribute': {\n" + + " 'key': 'rpc.system',\n" + + " 'value': {\n" + + " 'stringValue': 'grpc'\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + "}"), + ImmutableList.of("{'column_name':'resourceSpans','data_type':'struct'}"))); + 
interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"resourceSpans\",\"type\":\"struct\"}],\"datarows\":[[{\"scopeSpans\":{\"spans\":[{\"attribute\":{\"value\":{\"stringValue\":\"grpc\"},\"key\":\"rpc.system\"}},{\"attribute\":{\"value\":{\"stringValue\":\"grpc\"},\"key\":\"rpc.system\"}}]}}]],\"total\":1,\"size\":1}"); + } + + @Test + public void testExplainResponse() { + createAsyncQuery("EXPLAIN SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of("{'plan':'== Physical Plan ==\\nAdaptiveSparkPlan'}"), + ImmutableList.of("{'column_name':'plan','data_type':'string'}"))); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"plan\",\"type\":\"string\"}],\"datarows\":[[\"==" + + " Physical Plan ==\\n" + + "AdaptiveSparkPlan\"]],\"total\":1,\"size\":1}"); + } + + @Test + public void testInteractiveQueryEmptyResponseIssue2367() { + createAsyncQuery("SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of( + "{'srcPort':20641}", + "{'srcPort':20641}", + "{}", + "{}", + "{'srcPort':20641}", + "{'srcPort':20641}"), + ImmutableList.of("{'column_name':'srcPort','data_type':'long'}"))); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + "{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"srcPort\",\"type\":\"long\"}],\"datarows\":[[20641],[20641],[null],[null],[20641],[20641]],\"total\":6,\"size\":6}"); + } + + @Test + public void testInteractiveQueryArrayResponseIssue2367() { + createAsyncQuery("SELECT * FROM TABLE") + .withInteraction(InteractionStep::pluginSearchQueryResult) + .assertQueryResults("waiting", null) + .withInteraction( + interaction -> { + interaction.emrJobWriteResultDoc( + createResultDoc( + interaction.queryId, + ImmutableList.of( + "{'resourceSpans':[{'resource':{'attributes':[{'key':'telemetry.sdk.language','value':{'stringValue':'python'}},{'key':'telemetry.sdk.name','value':{'stringValue':'opentelemetry'}}]},'scopeSpans':[{'scope':{'name':'opentelemetry.instrumentation.grpc','version':'0.40b0'},'spans':[{'attributes':[{'key':'rpc.system','value':{'stringValue':'grpc'}},{'key':'rpc.grpc.status_code','value':{'intValue':'0'}}],'kind':3},{'attributes':[{'key':'rpc.system','value':{'stringValue':'grpc'}},{'key':'rpc.grpc.status_code','value':{'intValue':'0'}}],'kind':3}]}]}]}"), + ImmutableList.of("{'column_name':'resourceSpans','data_type':'array'}"))); + interaction.emrJobUpdateStatementState(StatementState.SUCCESS); + return interaction.pluginSearchQueryResult(); + }) + .assertFormattedQueryResults( + 
"{\"status\":\"SUCCESS\",\"schema\":[{\"name\":\"resourceSpans\",\"type\":\"array\"}],\"datarows\":[[[{\"resource\":{\"attributes\":[{\"value\":{\"stringValue\":\"python\"},\"key\":\"telemetry.sdk.language\"},{\"value\":{\"stringValue\":\"opentelemetry\"},\"key\":\"telemetry.sdk.name\"}]},\"scopeSpans\":[{\"spans\":[{\"kind\":3,\"attributes\":[{\"value\":{\"stringValue\":\"grpc\"},\"key\":\"rpc.system\"},{\"value\":{\"intValue\":\"0\"},\"key\":\"rpc.grpc.status_code\"}]},{\"kind\":3,\"attributes\":[{\"value\":{\"stringValue\":\"grpc\"},\"key\":\"rpc.system\"},{\"value\":{\"intValue\":\"0\"},\"key\":\"rpc.grpc.status_code\"}]}],\"scope\":{\"name\":\"opentelemetry.instrumentation.grpc\",\"version\":\"0.40b0\"}}]}]]],\"total\":1,\"size\":1}"); + } + private AssertionHelper createAsyncQuery(String query) { return new AssertionHelper(query, new LocalEMRSClient()); } @@ -231,6 +447,24 @@ AssertionHelper assertQueryResults(String status, List data) { assertEquals(data, results.getResults()); return this; } + + AssertionHelper assertFormattedQueryResults(String expected) { + AsyncQueryExecutionResponse results = + queryService.getAsyncQueryResults(createQueryResponse.getQueryId()); + + ResponseFormatter formatter = + new AsyncQueryResultResponseFormatter(JsonResponseFormatter.Style.COMPACT); + assertEquals( + expected, + formatter.format( + new AsyncQueryResult( + results.getStatus(), + results.getSchema(), + results.getResults(), + Cursor.None, + results.getError()))); + return this; + } } /** Define an interaction between PPL plugin and EMR-S job. */ @@ -299,9 +533,17 @@ private Map createEmptyResultDoc(String queryId) { } private Map createResultDoc(String queryId) { + return createResultDoc( + queryId, + ImmutableList.of("{'1':1}"), + ImmutableList.of("{'column_name" + "':'1','data_type':'integer'}")); + } + + private Map createResultDoc( + String queryId, List result, List schema) { Map document = new HashMap<>(); - document.put("result", ImmutableList.of("{'1':1}")); - document.put("schema", ImmutableList.of("{'column_name':'1','data_type':'integer'}")); + document.put("result", result); + document.put("schema", schema); document.put("jobRunId", "XXX"); document.put("applicationId", "YYY"); document.put("dataSourceName", DATASOURCE); diff --git a/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java b/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java new file mode 100644 index 0000000000..e58f240f5c --- /dev/null +++ b/spark/src/test/java/org/opensearch/sql/spark/data/value/SparkExprValueTest.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.spark.data.value; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.Test; +import org.opensearch.sql.spark.data.type.SparkDataType; + +class SparkExprValueTest { + @Test + public void type() { + assertEquals( + new SparkDataType("char"), new SparkExprValue(new SparkDataType("char"), "str").type()); + } + + @Test + public void unsupportedCompare() { + SparkDataType type = new SparkDataType("char"); + + assertThrows( + UnsupportedOperationException.class, + () -> new SparkExprValue(type, "str").compare(new SparkExprValue(type, "str"))); + } +} diff --git a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java index 188cd695a3..d44e3d271a 
100644
--- a/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java
+++ b/spark/src/test/java/org/opensearch/sql/spark/functions/SparkSqlFunctionTableScanOperatorTest.java
@@ -10,6 +10,7 @@
 import static org.junit.jupiter.api.Assertions.assertTrue;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.Mockito.when;
+import static org.opensearch.sql.data.model.ExprValueUtils.nullValue;
 import static org.opensearch.sql.data.model.ExprValueUtils.stringValue;
 import static org.opensearch.sql.spark.constants.TestConstants.QUERY;
 import static org.opensearch.sql.spark.utils.TestUtils.getJson;
@@ -18,6 +19,7 @@
 import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import lombok.SneakyThrows;
+import org.json.JSONArray;
 import org.json.JSONObject;
 import org.junit.jupiter.api.Assertions;
 import org.junit.jupiter.api.Test;
@@ -31,6 +33,7 @@
 import org.opensearch.sql.data.model.ExprFloatValue;
 import org.opensearch.sql.data.model.ExprIntegerValue;
 import org.opensearch.sql.data.model.ExprLongValue;
+import org.opensearch.sql.data.model.ExprNullValue;
 import org.opensearch.sql.data.model.ExprShortValue;
 import org.opensearch.sql.data.model.ExprStringValue;
 import org.opensearch.sql.data.model.ExprTimestampValue;
@@ -38,6 +41,8 @@
 import org.opensearch.sql.data.type.ExprCoreType;
 import org.opensearch.sql.executor.ExecutionEngine;
 import org.opensearch.sql.spark.client.SparkClient;
+import org.opensearch.sql.spark.data.type.SparkDataType;
+import org.opensearch.sql.spark.data.value.SparkExprValue;
 import org.opensearch.sql.spark.functions.scan.SparkSqlFunctionTableScanOperator;
 import org.opensearch.sql.spark.request.SparkQueryRequest;
 
@@ -134,7 +139,7 @@ void testQueryResponseAllTypes() {
             put("timestamp", new ExprDateValue("2023-07-01 10:31:30"));
             put("date", new ExprTimestampValue("2023-07-01 10:31:30"));
             put("string", new ExprStringValue("ABC"));
-            put("char", new ExprStringValue("A"));
+            put("char", new SparkExprValue(new SparkDataType("char"), "A"));
           }
         });
     assertEquals(firstRow, sparkSqlFunctionTableScanOperator.next());
@@ -143,19 +148,31 @@
 
   @Test
   @SneakyThrows
-  void testQueryResponseInvalidDataType() {
+  void testQueryResponseSparkDataType() {
     SparkQueryRequest sparkQueryRequest = new SparkQueryRequest();
     sparkQueryRequest.setSql(QUERY);
 
     SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator =
         new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest);
 
-    when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("invalid_data_type.json")));
-
-    RuntimeException exception =
-        Assertions.assertThrows(
-            RuntimeException.class, () -> sparkSqlFunctionTableScanOperator.open());
-    Assertions.assertEquals("Result contains invalid data type", exception.getMessage());
+    when(sparkClient.sql(any())).thenReturn(new JSONObject(getJson("spark_data_type.json")));
+    sparkSqlFunctionTableScanOperator.open();
+    assertEquals(
+        new ExprTupleValue(
+            new LinkedHashMap<>() {
+              {
+                put(
+                    "struct_column",
+                    new SparkExprValue(
+                        new SparkDataType("struct"),
+                        new JSONObject("{\"struct_value\":\"value\"}").toMap()));
+                put(
+                    "array_column",
+                    new SparkExprValue(
+                        new SparkDataType("array"), new JSONArray("[1,2]").toList()));
+              }
+            }),
+        sparkSqlFunctionTableScanOperator.next());
   }
 
   @Test
@@ -194,7 +211,7 @@ void issue2210() {
           {
             put("col_name", stringValue("day"));
            put("data_type", stringValue("int"));
-            put("comment", stringValue(""));
+            put("comment", nullValue());
          }
        }),
sparkSqlFunctionTableScanOperator.next()); @@ -224,10 +241,52 @@ void issue2210() { { put("col_name", stringValue("day")); put("data_type", stringValue("int")); - put("comment", stringValue("")); + put("comment", nullValue()); } }), sparkSqlFunctionTableScanOperator.next()); Assertions.assertFalse(sparkSqlFunctionTableScanOperator.hasNext()); } + + @Test + @SneakyThrows + public void issue2367MissingFields() { + SparkQueryRequest sparkQueryRequest = new SparkQueryRequest(); + sparkQueryRequest.setSql(QUERY); + + SparkSqlFunctionTableScanOperator sparkSqlFunctionTableScanOperator = + new SparkSqlFunctionTableScanOperator(sparkClient, sparkQueryRequest); + + when(sparkClient.sql(any())) + .thenReturn( + new JSONObject( + "{\n" + + " \"data\": {\n" + + " \"result\": [\n" + + " \"{}\",\n" + + " \"{'srcPort':20641}\"\n" + + " ],\n" + + " \"schema\": [\n" + + " \"{'column_name':'srcPort','data_type':'long'}\"\n" + + " ]\n" + + " }\n" + + "}")); + sparkSqlFunctionTableScanOperator.open(); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("srcPort", ExprNullValue.of()); + } + }), + sparkSqlFunctionTableScanOperator.next()); + assertEquals( + new ExprTupleValue( + new LinkedHashMap<>() { + { + put("srcPort", new ExprLongValue(20641L)); + } + }), + sparkSqlFunctionTableScanOperator.next()); + } } diff --git a/spark/src/test/resources/invalid_data_type.json b/spark/src/test/resources/invalid_data_type.json deleted file mode 100644 index 0eb08423c8..0000000000 --- a/spark/src/test/resources/invalid_data_type.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "data": { - "result": [ - "{'struct_column':'struct_value'}" - ], - "schema": [ - "{'column_name':'struct_column','data_type':'struct'}" - ], - "stepId": "s-123456789", - "applicationId": "application-abc" - } -} diff --git a/spark/src/test/resources/spark_data_type.json b/spark/src/test/resources/spark_data_type.json new file mode 100644 index 0000000000..79bd047f27 --- /dev/null +++ b/spark/src/test/resources/spark_data_type.json @@ -0,0 +1,13 @@ +{ + "data": { + "result": [ + "{'struct_column':{'struct_value':'value'},'array_column':[1,2]}" + ], + "schema": [ + "{'column_name':'struct_column','data_type':'struct'}", + "{'column_name':'array_column','data_type':'array'}" + ], + "stepId": "s-123456789", + "applicationId": "application-abc" + } +}
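
Reviewer note: the standalone sketch below condenses the two behavioral changes this patch makes in DefaultSparkSqlFunctionResponseHandle. First, getDataType() now falls through to a SparkDataType wrapper for any Spark type name it does not recognize ("struct", "array", "char", ...) instead of failing the whole query with "Result contains invalid data type". Second, extractRow() now maps a column that is absent from a result row to null instead of an empty string (issue 2367), and converts JSON containers to plain Map/List before wrapping them. This is an approximation, not the plugin's API: only the org.json calls (JSONObject.toMap, JSONArray.toList) are real; SparkType, resolveType, and extractRow are simplified stand-ins for SparkDataType, getDataType, and the handle's private extractRow.

import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.json.JSONArray;
import org.json.JSONObject;

public class SparkTypeFallbackSketch {

  /** Simplified stand-in for org.opensearch.sql.spark.data.type.SparkDataType. */
  static final class SparkType {
    final String typeName;

    SparkType(String typeName) {
      this.typeName = typeName;
    }

    @Override
    public String toString() {
      return "SparkType(" + typeName + ")";
    }
  }

  // Mirrors the reworked getDataType(): known names map to core types (modeled
  // here as plain strings); anything else is preserved in a wrapper rather
  // than rejected.
  static Object resolveType(String sparkType) {
    switch (sparkType) {
      case "boolean":
      case "long":
      case "integer":
      case "string":
        return sparkType.toUpperCase(); // placeholder for the ExprCoreType mapping
      default:
        return new SparkType(sparkType); // e.g. "struct", "array", "char"
    }
  }

  // Mirrors the reworked extractRow(): a column missing from the row becomes
  // null, and JSON containers become plain Map/List so response formatters can
  // render them as JSON.
  static Map<String, Object> extractRow(JSONObject row, List<String> columns) {
    Map<String, Object> out = new LinkedHashMap<>();
    for (String name : columns) {
      if (!row.has(name)) {
        out.put(name, null);
      } else {
        Object raw = row.get(name);
        if (raw instanceof JSONObject) {
          out.put(name, ((JSONObject) raw).toMap());
        } else if (raw instanceof JSONArray) {
          out.put(name, ((JSONArray) raw).toList());
        } else {
          out.put(name, raw);
        }
      }
    }
    return out;
  }

  public static void main(String[] args) {
    System.out.println(resolveType("struct")); // SparkType(struct)
    // org.json tolerates single-quoted JSON, which the test fixtures rely on.
    JSONObject row = new JSONObject("{'attributes':[1,2]}");
    System.out.println(extractRow(row, List.of("srcPort", "attributes")));
    // {srcPort=null, attributes=[1, 2]}
  }
}

Run with org.json (org.json:json) on the classpath. Keeping compare() unsupported in SparkExprValue is consistent with this design: a wrapped value is opaque to the engine, so there is no meaningful ordering, while equal() delegates to the underlying value for test assertions.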