diff --git a/coral-spark/build.gradle b/coral-spark/build.gradle index f3fbf025c..6c1e27a0f 100644 --- a/coral-spark/build.gradle +++ b/coral-spark/build.gradle @@ -3,10 +3,15 @@ apply from: "spark_itest.gradle" dependencies { compile project(':coral-hive') compile project(':coral-schema') + compile('org.apache.spark:spark-sql_2.13:3.2.0') { + exclude group: 'com.fasterxml.jackson', module: 'jackson-bom' + exclude group: 'org.apache.avro', module: 'avro-mapred' + } compileOnly deps.'spark'.'sql' + testCompile project(':coral-trino') testCompile(deps.'hive'.'hive-exec-core') { exclude group: 'org.apache.avro', module: 'avro-tools' // These exclusions are required to prevent duplicate classes since we include diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java b/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java index 97807a1ef..6545966a0 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java @@ -1,5 +1,5 @@ /** - * Copyright 2022 LinkedIn Corporation. All rights reserved. + * Copyright 2022-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ @@ -11,7 +11,12 @@ import com.google.common.base.Preconditions; -import org.apache.calcite.sql.*; +import org.apache.calcite.sql.SqlCall; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlKind; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlSelect; import org.apache.calcite.sql.fun.SqlStdOperatorTable; import org.apache.calcite.sql.parser.SqlParserPos; import org.apache.calcite.sql.util.SqlShuttle; diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java index 8ffba3f71..164b0f041 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java index af3e4a0f6..2982f8f9e 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java @@ -1,5 +1,5 @@ /** - * Copyright 2022-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2022-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java index c8a86d35c..bc321338d 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java index 97f2934a6..cc3a4fd3a 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java index 5cfa3fb6a..b50b73ddb 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java index 7e20bed7a..0b2a02def 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java index 1d229daa3..910a510a7 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java b/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java index faf330749..5a6d71b5e 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java b/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java index 4a33e8dba..23b861305 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 LinkedIn Corporation. All rights reserved. + * Copyright 2019-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java index ec6f128cc..e8b7b8913 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java index d0c08ecdb..edc4c30fc 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java new file mode 100644 index 000000000..88e48824d --- /dev/null +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java @@ -0,0 +1,354 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import com.google.gson.JsonPrimitive;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+import static com.linkedin.coral.common.calcite.CalciteUtil.createCall;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createLiteralBoolean;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createLiteralNumber;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createStringLiteral;
+
+
+/**
+ * Object for transforming an operator from one SQL language to another at the SqlNode layer.
+ *
+ * Suppose f1(a1, a2, ..., an) in the first language can be computed by
+ * f2(b1, b2, ..., bm) in the second language as follows:
+ *   (b1, b2, ..., bm) = g(a1, a2, ..., an)
+ *   f1(a1, a2, ..., an) = h(f2(g(a1, a2, ..., an)))
+ *
+ * We need to define two transformation functions:
+ * - A vector function g for transforming all operands
+ * - A function h for transforming the result.
+ *
+ * This class represents g and h as expressions in JSON format as follows:
+ * - Operators: +, -, *, /, ^, and %
+ * - Operands: source operands and literal values
+ *
+ * There may also be situations where a function in one language can map to more than one function in the other
+ * language, depending on the set of input parameters.
+ * We define a set of matching functions to determine what function name is used.
+ * Currently, there is no use case more complicated than matching a parameter string to a static regex.
+ *
+ * Example 1:
+ * In the input IR, TRUNCATE(aDouble, numDigitAfterDot) truncates aDouble by removing
+ * any digit from the position numDigitAfterDot after the dot, like truncate(11.45, 0) = 11,
+ * truncate(11.45, 1) = 11.4.
+ *
+ * In the target IR, TRUNCATE(aDouble) only takes one argument and removes all digits after the dot,
+ * like truncate(11.45) = 11.
+ *
+ * The transformation of TRUNCATE from one IR to another is represented as follows:
+ * 1. Target IR name: TRUNCATE
+ *
+ * 2. Operand transformers:
+ *    b1 = g(a1, a2) = a1 * 10 ^ a2, with JSON format:
+ *    [
+ *      { "op":"*",
+ *        "operands":[
+ *          {"input":1},  // input 0 is reserved for the result transformer; source inputs start from 1
+ *          { "op":"^",
+ *            "operands":[
+ *              {"value":10},
+ *              {"input":2}]}]}]
+ *
+ * 3. Result transformer:
+ *    h(result) = result / 10 ^ a2, with JSON format:
+ *    { "op":"/",
+ *      "operands":[
+ *        {"input":0},  // input 0 is for the result transformer
+ *        { "op":"^",
+ *          "operands":[
+ *            {"value":10},
+ *            {"input":2}]}]}
+ *
+ *
+ * 4. Operator transformers:
+ *    none
+ *
+ * Example 2:
+ * In the input IR, there exists a Hive-derived function to decode binary data given a format, DECODE(binary, scheme).
+ * In the target IR, there is no generic decoding function that takes a decoding scheme.
+ * Instead, there exist specific decoding functions that are first-class functions, like FROM_UTF8(binary).
+ * Consequently, we would need to know the operands in the function in order to determine the corresponding call.
+ *
+ * The transformation of DECODE from one IR to another is represented as follows:
+ * 1. Target IR name: There is no function name determined at compile time.
+ *    null
+ *
+ * 2. Operand transformers: We want to retain column 1 and drop column 2:
+ *    [{"input":1}]
+ *
+ * 3. Result transformer: No transformation is performed on the output.
+ *    null
+ *
+ * 4. Operator transformers: Check that the second parameter (scheme) matches 'utf-8' with any casing, using a Java regex.
+ *    [ {
+ *      "regex" : "^.*(?i)(utf-8).*$",
+ *      "input" : 2,
+ *      "name" : "from_utf8"
+ *    } ]
+ */
+class OperatorTransformer {
+  private static final Map<String, SqlOperator> OP_MAP = new HashMap<>();
+
+  // Operators allowed in the transformation
+  static {
+    OP_MAP.put("+", SqlStdOperatorTable.PLUS);
+    OP_MAP.put("-", SqlStdOperatorTable.MINUS);
+    OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY);
+    OP_MAP.put("/", SqlStdOperatorTable.DIVIDE);
+    OP_MAP.put("^", SqlStdOperatorTable.POWER);
+    OP_MAP.put("%", SqlStdOperatorTable.MOD);
+  }
+
+  public static final String OPERATOR = "op";
+  public static final String OPERANDS = "operands";
+  /**
+   * For the input node:
+   * - input equal to 0 refers to the result
+   * - input greater than 0 refers to the index of the source operand (starting from 1)
+   */
+  public static final String INPUT = "input";
+  public static final String VALUE = "value";
+  public static final String REGEX = "regex";
+  public static final String NAME = "name";
+
+  public final String fromOperatorName;
+  public final SqlOperator targetOperator;
+  public final List<JsonObject> operandTransformers;
+  public final JsonObject resultTransformer;
+  public final List<JsonObject> operatorTransformers;
+
+  private OperatorTransformer(String fromOperatorName, SqlOperator targetOperator, List<JsonObject> operandTransformers,
+      JsonObject resultTransformer, List<JsonObject> operatorTransformers) {
+    this.fromOperatorName = fromOperatorName;
+    this.targetOperator = targetOperator;
+    this.operandTransformers = operandTransformers;
+    this.resultTransformer = resultTransformer;
+    this.operatorTransformers = operatorTransformers;
+  }
+
+  /**
+   * Creates a new transformer.
+   *
+   * @param fromOperatorName Name of the function associated with this operator in the input IR
+   * @param targetOperator Operator in the target language
+   * @param operandTransformers JSON string representing the operand transformations,
+   *                            null for identity transformations
+   * @param resultTransformer JSON string representing the result transformation,
+   *                          null for identity transformation
+   * @param operatorTransformers JSON string representing an array of transformers that can vary the name of the target
+   *                             operator based on runtime parameter values.
+   *                             In the order of the JSON array, the first transformer that matches the JSON string will
+   *                             have its given operator name selected as the target operator name.
+   *                             Operands are indexed beginning at index 1.
+   *                             An operatorTransformer has the following serialized JSON string format:
+   *                             "[
+   *                               {
+   *                                 \"name\" : \"{Name of function if this matches}\",
+   *                                 \"input\" : {Index of the parameter starting at index 1 that is evaluated},
+   *                                 \"regex\" : \"{Java regex string matching the parameter at the given input}\"
+   *                               },
+   *                               ...
+   *                             ]"
+   *                             For example, a transformer for an operator named "foo" when parameter 2 matches exactly
+   *                             "bar" is specified as:
+   *                             "[
+   *                               {
+   *                                 \"name\" : \"foo\",
+   *                                 \"input\" : 2,
+   *                                 \"regex\" : \"'bar'\"
+   *                               }
+   *                             ]"
+   *                             NOTE: A string literal is represented exactly as ['STRING_LITERAL'] with the single
+   *                             quotation marks INCLUDED.
+   *                             As seen in the example above, the single quotation marks are also present in the
+   *                             regex matcher.
+   *
+   * @return {@link OperatorTransformer} object
+   */
+  public static OperatorTransformer of(@Nonnull String fromOperatorName, @Nonnull SqlOperator targetOperator,
+      @Nullable String operandTransformers, @Nullable String resultTransformer, @Nullable String operatorTransformers) {
+    List<JsonObject> operands = null;
+    JsonObject result = null;
+    List<JsonObject> operators = null;
+    if (operandTransformers != null) {
+      operands = parseJsonObjectsFromString(operandTransformers);
+    }
+    if (resultTransformer != null) {
+      result = new JsonParser().parse(resultTransformer).getAsJsonObject();
+    }
+    if (operatorTransformers != null) {
+      operators = parseJsonObjectsFromString(operatorTransformers);
+    }
+    return new OperatorTransformer(fromOperatorName, targetOperator, operands, result, operators);
+  }
+
+  /**
+   * Transforms a call to the source operator.
+   *
+   * @param sourceOperands Source operands
+   * @return An expression calling the target operator that is equivalent to the source operator call
+   */
+  public SqlNode transformCall(List<SqlNode> sourceOperands) {
+    final SqlOperator newTargetOperator = transformTargetOperator(targetOperator, sourceOperands);
+    if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) {
+      String operands = sourceOperands.stream().map(SqlNode::toString).collect(Collectors.joining(","));
+      throw new IllegalArgumentException(
+          String.format("An equivalent operator in the target IR was not found for the function call: %s(%s)",
+              fromOperatorName, operands));
+    }
+    final List<SqlNode> newOperands = transformOperands(sourceOperands);
+    final SqlCall newCall = createCall(newTargetOperator, newOperands, SqlParserPos.ZERO);
+    return transformResult(newCall, sourceOperands);
+  }
+
+  private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) {
+    if (operandTransformers == null) {
+      return sourceOperands;
+    }
+    final List<SqlNode> sources = new ArrayList<>();
+    // Add a dummy expression for input 0
+    sources.add(null);
+    sources.addAll(sourceOperands);
+    final List<SqlNode> results = new ArrayList<>();
+    for (JsonObject operandTransformer : operandTransformers) {
+      results.add(transformExpression(operandTransformer, sources));
+    }
+    return results;
+  }
+
+  private SqlNode transformResult(SqlNode result, List<SqlNode> sourceOperands) {
+    if (resultTransformer == null) {
+      return result;
+    }
+    final List<SqlNode> sources = new ArrayList<>();
+    // Result will be input 0
+    sources.add(result);
+    sources.addAll(sourceOperands);
+    return transformExpression(resultTransformer, sources);
+  }
+
+  /**
+   * Applies a single transformer expression to the given source operands.
+   */
+  private SqlNode transformExpression(JsonObject transformer, List<SqlNode> sourceOperands) {
+    if (transformer.get(OPERATOR) != null) {
+      final List<SqlNode> inputOperands = new ArrayList<>();
+      for (JsonElement inputOperand : transformer.getAsJsonArray(OPERANDS)) {
+        if (inputOperand.isJsonObject()) {
+          inputOperands.add(transformExpression(inputOperand.getAsJsonObject(), sourceOperands));
+        }
+      }
+      final String operatorName = transformer.get(OPERATOR).getAsString();
+      final SqlOperator op = OP_MAP.get(operatorName);
+      if (op == null) {
+        throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation");
+      }
+      return createCall(op, inputOperands, SqlParserPos.ZERO);
+    }
+    if (transformer.get(INPUT) != null) {
+      int index = transformer.get(INPUT).getAsInt();
+      if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) {
+        throw new IllegalArgumentException(
+            "Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size());
+      }
+      return sourceOperands.get(index);
+    }
+    final JsonElement value = transformer.get(VALUE);
+    if (value == null) {
+      throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value");
+    }
+    if (!value.isJsonPrimitive()) {
+      throw new IllegalArgumentException("Value should be of primitive type: " + value);
+    }
+
+    final JsonPrimitive primitive = value.getAsJsonPrimitive();
+    if (primitive.isString()) {
+      return createStringLiteral(primitive.getAsString(), SqlParserPos.ZERO);
+    }
+    if (primitive.isBoolean()) {
+      return createLiteralBoolean(primitive.getAsBoolean(), SqlParserPos.ZERO);
+    }
+    if (primitive.isNumber()) {
+      return createLiteralNumber(value.getAsBigDecimal().longValue(), SqlParserPos.ZERO);
+    }
+
+    throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive);
+  }
+
+  /**
+   * Returns a SqlOperator with a function name based on the value of the source operands.
+   */
+  private SqlOperator transformTargetOperator(SqlOperator operator, List<SqlNode> sourceOperands) {
+    if (operatorTransformers == null) {
+      return operator;
+    }
+
+    for (JsonObject operatorTransformer : operatorTransformers) {
+      if (!operatorTransformer.has(REGEX) || !operatorTransformer.has(INPUT) || !operatorTransformer.has(NAME)) {
+        throw new IllegalArgumentException(
+            "JSON node for target operator transformer must have a matcher, input and name");
+      }
+      // We use the same convention as operand and result transformers.
+      // Therefore, we start source index values at index 1 instead of index 0.
+      // Acceptable index values are set to be [1, size]
+      int index = operatorTransformer.get(INPUT).getAsInt() - 1;
+      if (index < 0 || index >= sourceOperands.size()) {
+        throw new IllegalArgumentException(
+            String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size()));
+      }
+      String functionName = operatorTransformer.get(NAME).getAsString();
+      if (functionName.isEmpty()) {
+        throw new IllegalArgumentException("JSON node for transformation must have a non-empty name");
+      }
+      String matcher = operatorTransformer.get(REGEX).getAsString();
+
+      if (Pattern.matches(matcher, sourceOperands.get(index).toString())) {
+        return Spark2CoralOperatorTransformerMapUtils.createOperator(functionName, operator.getReturnTypeInference(),
+            null);
+      }
+    }
+    return operator;
+  }
+
+  /**
+   * Creates an ArrayList of JsonObjects from a string input.
+   * The input string must be a serialized JSON array.
+   */
+  private static List<JsonObject> parseJsonObjectsFromString(String s) {
+    List<JsonObject> objects = new ArrayList<>();
+    JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray();
+    for (JsonElement object : transformerArray) {
+      objects.add(object.getAsJsonObject());
+    }
+    return objects;
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java
new file mode 100644
index 000000000..9510df33b
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.Locale;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.util.SqlShuttle;
+
+
+/**
+ * Rewrites the SqlNode tree to replace Spark SQL operators with Coral IR operators to obtain a Coral-compatible plan.
+ */
+public class Spark2CoralOperatorConverter extends SqlShuttle {
+  public Spark2CoralOperatorConverter() {
+  }
+
+  @Override
+  public SqlNode visit(final SqlCall call) {
+    final String operatorName = call.getOperator().getName();
+
+    final OperatorTransformer transformer = Spark2CoralOperatorTransformerMap
+        .getOperatorTransformer(operatorName.toLowerCase(Locale.ROOT), call.operandCount());
+
+    if (transformer == null) {
+      return super.visit(call);
+    }
+
+    return super.visit((SqlCall) transformer.transformCall(call.getOperandList()));
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java
new file mode 100644
index 000000000..68b3ee836
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.linkedin.coral.spark.spark2rel.Spark2CoralOperatorTransformerMapUtils.getKey;
+
+
+public class Spark2CoralOperatorTransformerMap {
+  private Spark2CoralOperatorTransformerMap() {
+  }
+
+  public static final Map<String, OperatorTransformer> TRANSFORMER_MAP = new HashMap<>();
+
+  static {
+    // TODO: keep adding Spark-specific functions as needed
+  }
+
+  /**
+   * Gets the {@link OperatorTransformer} for a given Spark SQL operator.
+   *
+   * @param sparkOpName Name of the Spark SQL operator
+   * @param numOperands Number of operands
+   * @return {@link OperatorTransformer} object
+   */
+  public static OperatorTransformer getOperatorTransformer(String sparkOpName, int numOperands) {
+    return TRANSFORMER_MAP.get(getKey(sparkOpName, numOperands));
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java
new file mode 100644
index 000000000..d9fa188e0
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java
@@ -0,0 +1,81 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.Map;
+
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
+
+
+public class Spark2CoralOperatorTransformerMapUtils {
+
+  private Spark2CoralOperatorTransformerMapUtils() {
+  }
+
+  /**
+   * Creates a mapping from a Spark SqlOperator name to an {@link OperatorTransformer}.
+   *
+   * @param transformerMap Map to store the result
+   * @param sparkOp Spark SQL operator
+   * @param numOperands Number of operands
+   * @param calciteOperatorName Name of the Calcite operator
+   */
+  static void createTransformerMapEntry(Map<String, OperatorTransformer> transformerMap, SqlOperator sparkOp,
+      int numOperands, String calciteOperatorName) {
+    createTransformerMapEntry(transformerMap, sparkOp, numOperands, calciteOperatorName, null, null);
+  }
+
+  /**
+   * Creates a mapping from a Spark SqlOperator name to a Calcite operator with the given Calcite operator name,
+   * operand transformer, and result transformer.
+   * To construct the Calcite SqlOperator from the Calcite operator name, this method reuses the return type
+   * inference from sparkOp, assuming equivalence.
+   *
+   * @param transformerMap Map to store the result
+   * @param sparkOp Spark SQL operator
+   * @param numOperands Number of operands
+   * @param calciteOperatorName Name of the Calcite operator
+   * @param operandTransformer Operand transformers, null for identity transformation
+   * @param resultTransformer Result transformer, null for identity transformation
+   */
+  static void createTransformerMapEntry(Map<String, OperatorTransformer> transformerMap, SqlOperator sparkOp,
+      int numOperands, String calciteOperatorName, String operandTransformer, String resultTransformer) {
+    createTransformerMapEntry(transformerMap, sparkOp, numOperands,
+        createOperator(calciteOperatorName, sparkOp.getReturnTypeInference(), sparkOp.getOperandTypeChecker()),
+        operandTransformer, resultTransformer);
+  }
+
+  /**
+   * Creates a mapping from a Spark SqlOperator name to a Calcite UDF with the given Calcite SqlOperator,
+   * operand transformer, and result transformer.
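+   * As a sketch with hypothetical operators, registering a two-argument Spark function "foo" whose Calcite
+   * equivalent takes the same operands in reverse order would look like:
+   * createTransformerMapEntry(transformerMap, fooOp, 2, fooCalciteOp, "[{\"input\":2},{\"input\":1}]", null)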
+   *
+   * @param transformerMap Map to store the result
+   * @param sparkOp Spark SQL operator
+   * @param numOperands Number of operands
+   * @param calciteSqlOperator The Calcite SqlOperator that is used as the target operator in the map
+   * @param operandTransformer Operand transformers, null for identity transformation
+   * @param resultTransformer Result transformer, null for identity transformation
+   */
+  static void createTransformerMapEntry(Map<String, OperatorTransformer> transformerMap, SqlOperator sparkOp,
+      int numOperands, SqlOperator calciteSqlOperator, String operandTransformer, String resultTransformer) {
+
+    transformerMap.put(getKey(sparkOp.getName(), numOperands),
+        OperatorTransformer.of(sparkOp.getName(), calciteSqlOperator, operandTransformer, resultTransformer, null));
+  }
+
+  static SqlOperator createOperator(String functionName, SqlReturnTypeInference returnTypeInference,
+      SqlOperandTypeChecker operandTypeChecker) {
+    return new SqlUserDefinedFunction(new SqlIdentifier(functionName, SqlParserPos.ZERO), returnTypeInference, null,
+        operandTypeChecker, null, null);
+  }
+
+  static String getKey(String sparkOpName, int numOperands) {
+    return sparkOpName + "_" + numOperands;
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java
new file mode 100644
index 000000000..d18605ea9
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java
@@ -0,0 +1,20 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import org.apache.calcite.sql.validate.SqlConformance;
+import org.apache.calcite.sql.validate.SqlConformanceEnum;
+import org.apache.calcite.sql.validate.SqlDelegatingConformance;
+
+
+public class SparkSqlConformance extends SqlDelegatingConformance {
+
+  public static final SqlConformance SPARK_SQL = new SparkSqlConformance();
+
+  private SparkSqlConformance() {
+    super(SqlConformanceEnum.PRAGMATIC_2003);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java
new file mode 100644
index 000000000..7dda0fd09
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java
@@ -0,0 +1,113 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.plan.Convention;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.core.Uncollect;
+import org.apache.calcite.rel.logical.LogicalValues;
+import org.apache.calcite.rel.metadata.JaninoRelMetadataProvider;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlExplainFormat;
+import org.apache.calcite.sql.SqlExplainLevel;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlUnnestOperator;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql2rel.SqlRexConvertletTable;
+import org.apache.calcite.sql2rel.SqlToRelConverter;
+
+import com.linkedin.coral.common.HiveUncollect;
+import com.linkedin.coral.hive.hive2rel.functions.HiveExplodeOperator;
+
+
+/**
+ * Class to convert Spark SQL to a Calcite RelNode. This class
+ * specializes the functionality provided by {@link SqlToRelConverter}.
+ */
+class SparkSqlToRelConverter extends SqlToRelConverter {
+
+  SparkSqlToRelConverter(RelOptTable.ViewExpander viewExpander, SqlValidator validator,
+      Prepare.CatalogReader catalogReader, RelOptCluster cluster, SqlRexConvertletTable convertletTable,
+      Config config) {
+    super(viewExpander, validator, catalogReader, cluster, convertletTable, config);
+  }
+
+  // This differs from the base class in two ways:
+  // 1. It does not validate the row type of the converted rel against that of the validated node, because
+  //    Hive is lax in enforcing view schemas.
+  // 2.
+  //    It skips calling some methods because (1) those are private, and (2) they are not required for our use case.
+  public RelRoot convertQuery(SqlNode query, final boolean needsValidation, final boolean top) {
+    if (needsValidation) {
+      query = validator.validate(query);
+    }
+
+    RelMetadataQuery.THREAD_PROVIDERS.set(JaninoRelMetadataProvider.of(cluster.getMetadataProvider()));
+    RelNode result = convertQueryRecursive(query, top, null).rel;
+    RelCollation collation = RelCollations.EMPTY;
+
+    if (SQL2REL_LOGGER.isDebugEnabled()) {
+      SQL2REL_LOGGER.debug(RelOptUtil.dumpPlan("Plan after converting SqlNode to RelNode", result,
+          SqlExplainFormat.TEXT, SqlExplainLevel.EXPPLAN_ATTRIBUTES));
+    }
+
+    final RelDataType validatedRowType = validator.getValidatedNodeType(query);
+    return RelRoot.of(result, validatedRowType, query.getKind()).withCollation(collation);
+  }
+
+  @Override
+  protected void convertFrom(Blackboard bb, SqlNode from) {
+    if (from == null) {
+      super.convertFrom(bb, from);
+      return;
+    }
+    switch (from.getKind()) {
+      case UNNEST:
+        convertUnnestFrom(bb, from);
+        break;
+      default:
+        super.convertFrom(bb, from);
+        break;
+    }
+  }
+
+  private void convertUnnestFrom(Blackboard bb, SqlNode from) {
+    final SqlCall call = (SqlCall) from;
+    final List<SqlNode> nodes = call.getOperandList();
+    final SqlUnnestOperator operator = (SqlUnnestOperator) call.getOperator();
+    // FIXME: the base class calls 'replaceSubqueries' for the operands here, but that's a private method.
+    // This is not an issue for our use cases with Hive, but we may need to handle it in the future.
+    final List<RexNode> exprs = new ArrayList<>();
+    final List<String> fieldNames = new ArrayList<>();
+    for (Ord<SqlNode> node : Ord.zip(nodes)) {
+      exprs.add(bb.convertExpression(node.e));
+      // In Hive, "LATERAL VIEW EXPLODE(arr) t" is equivalent to "LATERAL VIEW EXPLODE(arr) t AS col".
+      // Use the default column name "col" if not specified.
+      fieldNames.add(node.e.getKind() == SqlKind.AS ? validator.deriveAlias(node.e, node.i)
+          : HiveExplodeOperator.ARRAY_ELEMENT_COLUMN_NAME);
+    }
+    final RelNode input = RelOptUtil.createProject((null != bb.root) ? bb.root : LogicalValues.createOneRow(cluster),
+        exprs, fieldNames, true);
+    Uncollect uncollect =
+        new HiveUncollect(cluster, cluster.traitSetOf(Convention.NONE), input, operator.withOrdinality);
+    bb.setRoot(uncollect, true);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java
new file mode 100644
index 000000000..f0a49cfc9
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.volcano.VolcanoPlanner;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperatorTable;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql2rel.SqlRexConvertletTable;
+import org.apache.calcite.sql2rel.SqlToRelConverter;
+import org.apache.hadoop.hive.metastore.api.Table;
+
+import com.linkedin.coral.common.HiveMetastoreClient;
+import com.linkedin.coral.common.HiveRelBuilder;
+import com.linkedin.coral.common.ToRelConverter;
+import com.linkedin.coral.hive.hive2rel.DaliOperatorTable;
+import com.linkedin.coral.hive.hive2rel.HiveConvertletTable;
+import com.linkedin.coral.hive.hive2rel.HiveSqlValidator;
+import com.linkedin.coral.hive.hive2rel.functions.HiveFunctionResolver;
+import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry;
+import com.linkedin.coral.spark.spark2rel.parsetree.SparkParserDriver;
+
+import static com.linkedin.coral.spark.spark2rel.SparkSqlConformance.SPARK_SQL;
+
+
+/*
+ * We provide this class as a public interface by providing a thin wrapper
+ * around SparkSqlToRelConverter. Directly using SparkSqlToRelConverter will
+ * expose public methods from SqlToRelConverter. Use of SqlToRelConverter
+ * is likely to change in the future if we want more control over the
+ * conversion process. This class abstracts that out.
+ */
+public class SparkToRelConverter extends ToRelConverter {
+  private final HiveFunctionResolver functionResolver =
+      new HiveFunctionResolver(new StaticHiveFunctionRegistry(), new ConcurrentHashMap<>());
+  // The validator must be reused
+  private final SqlValidator sqlValidator = new HiveSqlValidator(getOperatorTable(), getCalciteCatalogReader(),
+      ((JavaTypeFactory) getRelBuilder().getTypeFactory()), SPARK_SQL);
+
+  public SparkToRelConverter(HiveMetastoreClient hiveMetastoreClient) {
+    super(hiveMetastoreClient);
+  }
+
+  public SparkToRelConverter(Map<String, Map<String, List<String>>> localMetaStore) {
+    super(localMetaStore);
+  }
+
+  @Override
+  protected SqlRexConvertletTable getConvertletTable() {
+    return new HiveConvertletTable();
+  }
+
+  @Override
+  protected SqlValidator getSqlValidator() {
+    return sqlValidator;
+  }
+
+  @Override
+  protected SqlOperatorTable getOperatorTable() {
+    return ChainedSqlOperatorTable.of(SqlStdOperatorTable.instance(), new DaliOperatorTable(functionResolver));
+  }
+
+  @Override
+  protected SqlToRelConverter getSqlToRelConverter() {
+    return new SparkSqlToRelConverter(new SparkViewExpander(this), getSqlValidator(), getCalciteCatalogReader(),
+        RelOptCluster.create(new VolcanoPlanner(), getRelBuilder().getRexBuilder()), getConvertletTable(),
+        SqlToRelConverter.configBuilder().withRelBuilderFactory(HiveRelBuilder.LOGICAL_BUILDER).build());
+  }
+
+  @Override
+  protected SqlNode toSqlNode(String sql, Table sparkView) {
+    String trimmedSql = trimParenthesis(sql.toUpperCase());
+    SqlNode parsedSqlNode = SparkParserDriver.parse(trimmedSql);
+    return parsedSqlNode.accept(new Spark2CoralOperatorConverter());
+  }
+
+  private static String trimParenthesis(String value) {
+    String str = value.trim();
+    if (str.startsWith("(") && str.endsWith(")")) {
+      return trimParenthesis(str.substring(1, str.length() - 1));
+    }
+    return str;
+  }
+
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java
new file mode 100644
index 000000000..f4958c2d6
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.util.Util;
+
+import com.linkedin.coral.common.FuzzyUnionSqlRewriter;
+
+
+/**
+ * Class that implements the {@link RelOptTable.ViewExpander}
+ * interface to support expansion of Spark views to relational algebra.
+ */
+public class SparkViewExpander implements RelOptTable.ViewExpander {
+
+  private final SparkToRelConverter sparkToRelConverter;
+
+  /**
+   * Instantiates a new Spark view expander.
+   *
+   * @param sparkToRelConverter Spark to Rel converter
+   */
+  public SparkViewExpander(@Nonnull SparkToRelConverter sparkToRelConverter) {
+    this.sparkToRelConverter = sparkToRelConverter;
+  }
+
+  @Override
+  public RelRoot expandView(RelDataType rowType, String queryString, List<String> schemaPath, List<String> viewPath) {
+    Preconditions.checkNotNull(viewPath);
+    Preconditions.checkState(!viewPath.isEmpty());
+
+    String dbName = Util.last(schemaPath);
+    String tableName = viewPath.get(0);
+
+    SqlNode sqlNode = sparkToRelConverter.processView(dbName, tableName)
+        .accept(new FuzzyUnionSqlRewriter(tableName, sparkToRelConverter));
+    return sparkToRelConverter.getSqlToRelConverter().convertQuery(sqlNode, true, true);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java
new file mode 100644
index 000000000..21fa4b026
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel.parsetree;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
+import org.apache.spark.sql.execution.SparkSqlParser;
+
+
+public class SparkParserDriver {
+  private static final SparkSqlParser SPARK_SQL_PARSER = new SparkSqlParser();
+
+  /**
+   * Uses the SparkSqlParser to parse the command and return the Calcite SqlNode.
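+   * For example, parse("SELECT A FROM DB.TBL WHERE A > 1") returns the corresponding SqlSelect tree
+   * (assuming the statement uses only constructs handled by SparkSqlAstVisitor).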
+   *
+   * @param command Spark SQL statement to parse
+   * @return the parsed {@link SqlNode}
+   */
+  public static SqlNode parse(String command) {
+    ParserRuleContext context = SPARK_SQL_PARSER.parse(command, new SparkSqlAstBuilder());
+    SparkSqlAstVisitor visitor = new SparkSqlAstVisitor();
+    return visitor.visitSingleStatement((SqlBaseParser.SingleStatementContext) context);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java
new file mode 100644
index 000000000..c8b64f1e1
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java
@@ -0,0 +1,20 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel.parsetree;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.apache.spark.sql.catalyst.parser.AstBuilder;
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
+
+import scala.Function1;
+
+
+public class SparkSqlAstBuilder extends AstBuilder implements Function1<SqlBaseParser, ParserRuleContext> {
+  @Override
+  public ParserRuleContext apply(SqlBaseParser v1) {
+    return v1.singleStatement();
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java
new file mode 100644
index 000000000..d9e666be8
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java
@@ -0,0 +1,381 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel.parsetree;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.antlr.v4.runtime.tree.ParseTree;
+import org.antlr.v4.runtime.tree.RuleNode;
+import org.antlr.v4.runtime.tree.TerminalNode;
+import org.apache.calcite.sql.SqlDataTypeSpec;
+import org.apache.calcite.sql.SqlFunctionCategory;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlLiteral;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlOrderBy;
+import org.apache.calcite.sql.SqlSelect;
+import org.apache.calcite.sql.SqlUnresolvedFunction;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.calcite.sql.type.SqlTypeUtil;
+import org.apache.spark.sql.catalyst.parser.SqlBaseBaseVisitor;
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
+
+import com.linkedin.coral.common.HiveTypeSystem;
+import com.linkedin.coral.common.calcite.CalciteUtil;
+
+import static com.linkedin.coral.common.calcite.CalciteUtil.createCall;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createSqlIdentifier;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createStarIdentifier;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.AS;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BETWEEN;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CONCAT;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.DIVIDE;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.DIVIDE_INTEGER;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EQUALS;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.GREATER_THAN;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.GREATER_THAN_OR_EQUAL;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IN;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_DISTINCT_FROM;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_FALSE;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_NULL;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_TRUE;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_UNKNOWN;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN_OR_EQUAL;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LIKE;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MINUS;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT_EQUALS;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PERCENT_REMAINDER;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS;
+import static org.apache.calcite.sql.parser.SqlParserPos.ZERO;
+
+
+public class SparkSqlAstVisitor extends SqlBaseBaseVisitor<SqlNode> {
+  private static final SqlTypeFactoryImpl SQL_TYPE_FACTORY = new SqlTypeFactoryImpl(new HiveTypeSystem());
+  private static final String UNSUPPORTED_EXCEPTION_MSG = "%s at line %d column %d is not supported by this visitor.";
+
+  @Override
+  public SqlNode visitChildren(RuleNode node) {
+    if (node.getChildCount() == 2 && node.getChild(1) instanceof TerminalNode) {
+      return node.getChild(0).accept(this);
+    }
+    return super.visitChildren(node);
+  }
+
+  @Override
+  public SqlNode visitQuery(SqlBaseParser.QueryContext ctx) {
+    SqlSelect select = (SqlSelect) ((SqlNodeList) visit(ctx.queryTerm())).get(0);
+    SqlNodeList orderBy = ctx.queryOrganization().order.isEmpty() ? null : new SqlNodeList(
+        ctx.queryOrganization().order.stream().map(this::visit).collect(Collectors.toList()), getPos(ctx));
+    SqlNode limit = ctx.queryOrganization().limit == null ? null : visit(ctx.queryOrganization().limit);
+    if (orderBy != null || limit != null) {
+      return new SqlOrderBy(getPos(ctx), select, orderBy, null, limit);
+    }
+    return select;
+  }
+
+  @Override
+  public SqlNode visitRegularQuerySpecification(SqlBaseParser.RegularQuerySpecificationContext ctx) {
+    SqlNodeList selectList = visitSelectClause(ctx.selectClause());
+    SqlNode from = ctx.fromClause() == null ? null : visitFromClause(ctx.fromClause());
+    SqlNode where = ctx.whereClause() == null ? null : visitWhereClause(ctx.whereClause());
+    SqlNodeList groupBy =
+        ctx.aggregationClause() == null ? null : visitGroupByClause(ctx.aggregationClause().groupByClause);
+    SqlNode having = ctx.havingClause() == null ? null : visitHavingClause(ctx.havingClause());
+    return new SqlSelect(getPos(ctx), null, selectList, from, where, groupBy, having, null, null, null, null);
+  }
+
+  @Override
+  public SqlNode visitWhereClause(SqlBaseParser.WhereClauseContext ctx) {
+    return super.visitWhereClause(ctx);
+  }
+
+  @Override
+  public SqlNode visitCtes(SqlBaseParser.CtesContext ctx) {
+    throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG);
+  }
+
+  @Override
+  public SqlNode visitJoinRelation(SqlBaseParser.JoinRelationContext ctx) {
+    throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG);
+  }
+
+  @Override
+  public SqlNode visitJoinType(SqlBaseParser.JoinTypeContext ctx) {
+    throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG);
+  }
+
+  @Override
+  public SqlNode visitJoinCriteria(SqlBaseParser.JoinCriteriaContext ctx) {
+    throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG);
+  }
+
+  @Override
+  public SqlNode visitPredicated(SqlBaseParser.PredicatedContext ctx) {
+    SqlNode expression = visit(ctx.valueExpression());
+    if (ctx.predicate() != null) {
+      return withPredicate(expression, ctx.predicate());
+    }
+    return expression;
+  }
+
+  private SqlNode withPredicate(SqlNode expression, SqlBaseParser.PredicateContext ctx) {
+    SqlNode predicate = toPredicate(expression, ctx);
+    if (ctx.NOT() == null) {
+      return predicate;
+    }
+    return NOT.createCall(getPos(ctx), predicate);
+  }
+
+  private SqlNode toPredicate(SqlNode expression, SqlBaseParser.PredicateContext ctx) {
+    SqlParserPos position = getPos(ctx);
+    int type = ctx.kind.getType();
+    switch (type) {
+      case SqlBaseParser.BETWEEN:
+        return BETWEEN.createCall(position, expression, visit(ctx.right));
+      case SqlBaseParser.IN:
+        return IN.createCall(position, expression, visit(ctx.right));
+      case SqlBaseParser.LIKE:
+        return LIKE.createCall(position, expression, visit(ctx.right));
+      case SqlBaseParser.RLIKE:
+        throw new UnsupportedOperationException("Unsupported predicate type: RLIKE");
+      case SqlBaseParser.NULL:
+        return IS_NULL.createCall(position, expression);
+      case SqlBaseParser.TRUE:
+        return IS_TRUE.createCall(position, expression);
+      case SqlBaseParser.FALSE:
+        return IS_FALSE.createCall(position, expression);
+      case SqlBaseParser.UNKNOWN:
+        return IS_UNKNOWN.createCall(position, expression);
+      case SqlBaseParser.DISTINCT:
+        return IS_DISTINCT_FROM.createCall(position, expression, visit(ctx.right));
+      default:
+        throw new UnsupportedOperationException("Unsupported predicate type: " + type);
+    }
+  }
+
+  @Override
+  public SqlNode visitComparison(SqlBaseParser.ComparisonContext ctx) {
+    SqlParserPos position = getPos(ctx);
+    SqlNode left = visit(ctx.left);
+    SqlNode right = visit(ctx.right);
+    TerminalNode operator = (TerminalNode) ctx.comparisonOperator().getChild(0);
+    switch (operator.getSymbol().getType()) {
+      case SqlBaseParser.EQ:
+        return EQUALS.createCall(position, left, right);
+      case SqlBaseParser.NSEQ:
+        throw new UnsupportedOperationException("Unsupported operator: NSEQ (NullSafeEqual)");
+      case SqlBaseParser.NEQ:
+      case SqlBaseParser.NEQJ:
+        return NOT_EQUALS.createCall(position, left, right);
+      case SqlBaseParser.LT:
+        return LESS_THAN.createCall(position, left, right);
+      case SqlBaseParser.LTE:
+        return LESS_THAN_OR_EQUAL.createCall(position, left, right);
+      case SqlBaseParser.GT:
+        return GREATER_THAN.createCall(position, left, right);
+      case SqlBaseParser.GTE:
+        return GREATER_THAN_OR_EQUAL.createCall(position, left, right);
+    }
+    throw new UnsupportedOperationException("Unsupported comparison operator: " + operator.getText());
+  }
+
+  @Override
+  public SqlNodeList visitSelectClause(SqlBaseParser.SelectClauseContext ctx) {
+    return getChildSqlNodeList(ctx.namedExpressionSeq());
+  }
+
+  @Override
+  public SqlNodeList visitGroupByClause(SqlBaseParser.GroupByClauseContext ctx) {
+    return getChildSqlNodeList(ctx.children);
+  }
+
+  @Override
+  public SqlNode visitTableIdentifier(SqlBaseParser.TableIdentifierContext ctx) {
+    return createSqlIdentifier(getPos(ctx), ctx.db.getText(), ctx.table.getText());
+  }
+
+  @Override
+  public SqlNode visitTableName(SqlBaseParser.TableNameContext ctx) {
+    if (ctx.tableAlias().children != null) {
+      List<SqlNode> operands = new ArrayList<>();
+      operands.add(visit(ctx.getChild(0)));
+      operands.addAll(visitTableAlias(ctx.tableAlias()).getList());
+      return AS.createCall(getPos(ctx), operands);
+    }
+    return visitMultipartIdentifier(ctx.multipartIdentifier());
+  }
+
+  @Override
+  public SqlNodeList visitTableAlias(SqlBaseParser.TableAliasContext ctx) {
+    List<SqlNode> operands = new ArrayList<>();
+    operands.add(visit(ctx.getChild(0)));
+    if (ctx.children.size() > 1) {
+      operands.addAll(((SqlNodeList) visit(ctx.getChild(1))).getList());
+    }
+    return new SqlNodeList(operands, getPos(ctx));
+  }
+
+  @Override
+  public SqlNode visitQuotedIdentifierAlternative(SqlBaseParser.QuotedIdentifierAlternativeContext ctx) {
+    return createSqlIdentifier(getPos(ctx), ctx.getText());
+  }
+
+  @Override
+  public SqlNode visitUnquotedIdentifier(SqlBaseParser.UnquotedIdentifierContext ctx) {
+    return createSqlIdentifier(getPos(ctx), ctx.getText());
+  }
+
+  @Override
+  public SqlNode visitMultipartIdentifier(SqlBaseParser.MultipartIdentifierContext ctx) {
+    return createSqlIdentifier(getPos(ctx),
+        ctx.parts.stream().map(part -> part.identifier().getText()).toArray(String[]::new));
+  }
+
+  @Override
+  public SqlNode visitCast(SqlBaseParser.CastContext ctx) {
+    SqlDataTypeSpec spec = toSqlDataTypeSpec(ctx);
+    return CAST.createCall(getPos(ctx), visit(ctx.expression()), spec);
+  }
+
+  private SqlDataTypeSpec toSqlDataTypeSpec(SqlBaseParser.CastContext ctx) {
+    return SqlTypeUtil.convertTypeToSpec(
+        SQL_TYPE_FACTORY.createSqlType(SqlTypeName.valueOf(ctx.dataType().getText().toUpperCase())));
+  }
+
+  @Override
+  public SqlNode visitFunctionCall(SqlBaseParser.FunctionCallContext ctx) {
+    SqlIdentifier functionName = createSqlIdentifier(getPos(ctx), ctx.functionName().getText());
+    SqlUnresolvedFunction unresolvedFunction =
+        new SqlUnresolvedFunction(functionName, null, null, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION);
+    List<SqlNode> operands = ctx.argument.stream().map(this::visit).collect(Collectors.toList());
+    return createCall(unresolvedFunction, operands, getPos(ctx));
+  }
+
+  @Override
+  public SqlNode visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) {
+    return ctx.query().accept(this);
+  }
+
+  @Override
+  public SqlNode visitQueryTermDefault(SqlBaseParser.QueryTermDefaultContext ctx) {
+    return getChildSqlNodeList(ctx);
+  }
+
+  @Override
+  public SqlNode visitSubscript(SqlBaseParser.SubscriptContext ctx) {
+    throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG);
+  }
+
+  @Override
+  public SqlNode visitNamedExpression(SqlBaseParser.NamedExpressionContext ctx) {
+    if (ctx.name != null) {
+      return AS.createCall(getPos(ctx), visit(ctx.getChild(0)),
+          createSqlIdentifier(getPos(ctx), ctx.name.identifier().getText()));
+    }
+    String text = ctx.getText();
+    if (text.equals("*")) {
+      return createStarIdentifier(getPos(ctx));
+    }
+    return super.visitNamedExpression(ctx);
+  }
+
+  @Override
+  public SqlNode visitIdentifierList(SqlBaseParser.IdentifierListContext ctx) {
+    List<SqlNode> identifiers = ctx.identifierSeq().ident.stream()
+        .map(identifier -> createSqlIdentifier(getPos(identifier), identifier.getText())).collect(Collectors.toList());
+    return new SqlNodeList(identifiers, getPos(ctx));
+  }
+
+  @Override
+  public SqlNode visitColumnReference(SqlBaseParser.ColumnReferenceContext ctx) {
+    return createSqlIdentifier(getPos(ctx), ctx.getText());
+  }
+
+  @Override
+  public SqlNode visitArithmeticBinary(SqlBaseParser.ArithmeticBinaryContext ctx) {
+    SqlNode left = visit(ctx.left);
+    SqlNode right = visit(ctx.right);
+    SqlParserPos position = getPos(ctx);
+    switch (ctx.operator.getType()) {
+      case SqlBaseParser.ASTERISK:
+        return MULTIPLY.createCall(position, left, right);
+      case SqlBaseParser.SLASH:
+        return DIVIDE.createCall(position, left, right);
+      case SqlBaseParser.PERCENT:
+        return PERCENT_REMAINDER.createCall(position, left, right);
+      case SqlBaseParser.DIV:
+        return DIVIDE_INTEGER.createCall(position, left, right);
+      case SqlBaseParser.PLUS:
+        return PLUS.createCall(position, left, right);
+      case SqlBaseParser.MINUS:
+        return MINUS.createCall(position, left, right);
+      case SqlBaseParser.CONCAT_PIPE:
+        return CONCAT.createCall(position, left, right);
+      case SqlBaseParser.AMPERSAND:
+        throw new UnsupportedOperationException("Unsupported arithmetic binary: &");
+      case SqlBaseParser.HAT:
+        throw new UnsupportedOperationException("Unsupported arithmetic binary: ^");
+      case SqlBaseParser.PIPE:
+        throw new UnsupportedOperationException("Unsupported arithmetic binary: |");
+    }
+    throw new UnsupportedOperationException("Unsupported arithmetic binary: " + ctx.operator);
+  }
+
+  @Override
+  public SqlNode visitBooleanLiteral(SqlBaseParser.BooleanLiteralContext ctx) {
+    boolean value = ctx.getText().equalsIgnoreCase("true");
+    return CalciteUtil.createLiteralBoolean(value, getPos(ctx));
+  }
+
+  @Override
+  public SqlNode visitNumericLiteral(SqlBaseParser.NumericLiteralContext ctx) {
+    return SqlLiteral.createExactNumeric(ctx.getText(), getPos(ctx));
+  }
+
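+  // Note: ctx.getText() for a string literal includes the surrounding quotes, so the visitor below strips
+  // the first and last characters.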
@Override + public SqlNode visitStringLiteral(SqlBaseParser.StringLiteralContext ctx) { + String text = ctx.getText().substring(1, ctx.getText().length() - 1); + return CalciteUtil.createStringLiteral(text, getPos(ctx)); + } + + private SqlNodeList getChildSqlNodeList(ParserRuleContext ctx) { + return new SqlNodeList(getChildren(ctx), getPos(ctx)); + } + + private SqlNodeList getChildSqlNodeList(List<ParseTree> nodes) { + return new SqlNodeList(toListOfSqlNode(nodes), ZERO); + } + + private List<SqlNode> getChildren(ParserRuleContext node) { + return toListOfSqlNode(node.children); + } + + private List<SqlNode> toListOfSqlNode(List<ParseTree> nodes) { + if (nodes == null) { + return Collections.emptyList(); + } + return nodes.stream().map(this::visit).filter(Objects::nonNull).collect(Collectors.toList()); + } + + private SqlParserPos getPos(ParserRuleContext ctx) { + if (ctx.start != null) { + return new SqlParserPos(ctx.start.getLine(), ctx.start.getStartIndex()); + } + return ZERO; + } +} diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java new file mode 100644 index 000000000..611b11fc4 --- /dev/null +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java @@ -0,0 +1,23 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.spark.spark2rel.parsetree; + +import org.antlr.v4.runtime.ParserRuleContext; + + +public class UnhandledASTNodeException extends RuntimeException { + public UnhandledASTNodeException(ParserRuleContext context, String message) { + super(String.format(message, context.getClass().getSimpleName(), getLine(context), getColumn(context))); + } + + private static int getLine(ParserRuleContext context) { + return context.start.getLine(); + } + + private static int getColumn(ParserRuleContext context) { + return context.start.getStartIndex(); + } +} diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java index 27d6884b1..7a0b5fb4f 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java index a727ca37c..cda8ccb28 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license.
 * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java index 42ec14e8b..7655f954f 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java index 272fd197d..cf522b910 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java b/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java index dc0ae839c..ba943fdcc 100644 --- a/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java +++ b/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java b/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java index 9d68f3c6d..69e246662 100644 --- a/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java +++ b/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java index f4bb82491..2554ace8e 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information.
*/ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java index 1760817e8..dbcedb20a 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java index 980c06621..e8fa7379e 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java index 61031127a..1e0515e87 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java index 134119268..2a031915a 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java index 666ea87e3..19e1c346c 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information.
*/ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java index 0c3634b05..40a98f6b0 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2019-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java b/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java index 845e4ba39..1a97883d7 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java new file mode 100644 index 000000000..987ca8c25 --- /dev/null +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java @@ -0,0 +1,173 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.spark.spark2rel; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; + +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.metadata.Hive; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import com.linkedin.coral.common.HiveMscAdapter; +import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry; +import com.linkedin.coral.trino.rel2trino.RelToTrinoConverter; + +import static com.linkedin.coral.spark.spark2rel.Spark2CoralOperatorTransformerMapUtils.createOperator; +import static com.linkedin.coral.spark.spark2rel.Spark2CoralOperatorTransformerMapUtils.createTransformerMapEntry; +import static com.linkedin.coral.spark.spark2rel.ToRelTestUtils.CORAL_FROM_TRINO_TEST_DIR; +import static com.linkedin.coral.spark.spark2rel.ToRelTestUtils.sparkToRelConverter; +import static org.apache.calcite.sql.type.OperandTypes.NILADIC; +import static org.apache.calcite.sql.type.OperandTypes.NUMERIC; +import static org.apache.calcite.sql.type.OperandTypes.NUMERIC_NUMERIC; +import static org.apache.calcite.sql.type.OperandTypes.or; +import static org.testng.AssertJUnit.assertEquals; + + +public class SparkToRelConverterTest { + private static HiveConf conf; +
+ @BeforeClass + public void beforeClass() throws HiveException, IOException, MetaException { + // Simulating a Coral environment where "foo" exists + StaticHiveFunctionRegistry.createAddUserDefinedFunction("foo", ReturnTypes.INTEGER, + or(NILADIC, NUMERIC, NUMERIC_NUMERIC)); + + conf = ToRelTestUtils.loadResourceHiveConf(); + ToRelTestUtils.initializeViews(conf); + + Map TRANSFORMER_MAP = Spark2CoralOperatorTransformerMap.TRANSFORMER_MAP; + + // foo(a) or foo() + createTransformerMapEntry(TRANSFORMER_MAP, createOperator("foo", ReturnTypes.INTEGER, or(NILADIC, NUMERIC)), 1, + "foo", null, null); + + // foo(a, b) => foo((10 * a) + (10 * b)) + createTransformerMapEntry(TRANSFORMER_MAP, createOperator("foo", ReturnTypes.INTEGER, NUMERIC_NUMERIC), 2, "foo", + "[{\"op\":\"+\",\"operands\":[{\"op\":\"*\",\"operands\":[{\"value\":10},{\"input\":1}]},{\"op\":\"*\",\"operands\":[{\"value\":10},{\"input\":2}]}]}]", + null); + } + + @AfterTest + public void afterClass() throws IOException { + FileUtils.deleteDirectory(new File(conf.get(CORAL_FROM_TRINO_TEST_DIR))); + } + + @DataProvider(name = "support") + public Iterator<Object[]> getSupportedSql() { + return ImmutableList.<Spark2TrinoDataProvider>builder() + .add(new Spark2TrinoDataProvider("select * from foo", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select * from foo /* end */", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("/* start */ select * from foo", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("/* start */ select * /* middle */ from foo /* end */", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("-- start \n select * -- junk -- hi\n from foo -- done", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select * from foo a (v, w, x, y, z)", + "LogicalProject(V=[$0], W=[$1], X=[$2], Y=[$3], Z=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"foo\".\"show\" AS \"V\", \"foo\".\"a\" AS \"W\", \"foo\".\"b\" AS \"X\", \"foo\".\"x\" AS \"Y\", \"foo\".\"y\" AS \"Z\"\n" + + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select *, 123, * from foo", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4], EXPR$5=[123], show0=[$0], a0=[$1], b0=[$2], x0=[$3], y0=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"foo\".\"show\" AS \"show\", \"foo\".\"a\" AS \"a\", \"foo\".\"b\" AS \"b\", \"foo\".\"x\" AS \"x\", \"foo\".\"y\" AS \"y\", 123, \"foo\".\"show\" AS \"show0\", \"foo\".\"a\" AS \"a0\", \"foo\".\"b\" AS \"b0\", \"foo\".\"x\" AS \"x0\", \"foo\".\"y\" AS \"y0\"\n" + + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select show from foo", + "LogicalProject(SHOW=[$0])\n" + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"foo\".\"show\"
AS \"SHOW\"\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select 1 + 13 || '15' from foo", + "LogicalProject(EXPR$0=[||(CAST(+(1, 13)):VARCHAR(65535) NOT NULL, '15')])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT CAST(1 + 13 AS VARCHAR(65535)) || '15'\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select x is distinct from y from foo where a is not distinct from b", + "LogicalProject(EXPR$0=[AND(OR(IS NOT NULL($3), IS NOT NULL($4)), IS NOT TRUE(=($3, $4)))])\n" + + " LogicalFilter(condition=[NOT(AND(OR(IS NOT NULL($1), IS NOT NULL($2)), IS NOT TRUE(=($1, $2))))])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT (\"foo\".\"x\" IS NOT NULL OR \"foo\".\"y\" IS NOT NULL) AND \"foo\".\"x\" = \"foo\".\"y\" IS NOT TRUE\n" + + "FROM \"default\".\"foo\" AS \"foo\"\n" + + "WHERE NOT ((\"foo\".\"a\" IS NOT NULL OR \"foo\".\"b\" IS NOT NULL) AND \"foo\".\"a\" = \"foo\".\"b\" IS NOT TRUE)")) + .add(new Spark2TrinoDataProvider("select cast('123' as bigint)", + "LogicalProject(EXPR$0=[CAST('123'):BIGINT])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT CAST('123' AS BIGINT)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select a `my price` from `foo` `ORDERS`", + "LogicalProject(MY PRICE=[$1])\n" + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"foo\".\"a\" AS \"MY PRICE\"\n" + "FROM \"default\".\"foo\" AS \"foo\"")) + .add(new Spark2TrinoDataProvider("select * from a limit all", + "LogicalProject(b=[$0], id=[$1], x=[$2])\n" + " LogicalTableScan(table=[[hive, default, a]])\n", + "SELECT *\n" + "FROM \"default\".\"a\" AS \"a\"")) + .add(new Spark2TrinoDataProvider("select * from a order by x limit all", + "LogicalSort(sort0=[$2], dir0=[ASC-nulls-first])\n" + " LogicalProject(b=[$0], id=[$1], x=[$2])\n" + + " LogicalTableScan(table=[[hive, default, a]])\n", + "SELECT *\n" + "FROM \"default\".\"a\" AS \"a\"\n" + "ORDER BY \"a\".\"x\" NULLS FIRST")) + .add(new Spark2TrinoDataProvider("select foo(3)", + "LogicalProject(EXPR$0=[foo(3)])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"(3)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select FOO(3)", + "LogicalProject(EXPR$0=[foo(3)])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"(3)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select foo()", + "LogicalProject(EXPR$0=[foo()])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"()\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select foo(10, 2)", + "LogicalProject(EXPR$0=[foo(+(*(10, 10), *(10, 2)))])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"(10 * 10 + 10 * 2)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .build().stream().map(x -> new Object[] { x.sparkSql, x.explain, x.trinoSql }).iterator(); + } + + private static class Spark2TrinoDataProvider { + private final String sparkSql; + private final String explain; + private final String trinoSql; + + public Spark2TrinoDataProvider(@Language("SQL") String sparkSql, String explain, @Language("SQL") String trinoSql) { + this.sparkSql = sparkSql; + this.explain = explain; + this.trinoSql = trinoSql; + } + } + + //TODO: Add unsupported SQL tests + + @Test(dataProvider = "support") + public void testSupport(String trinoSql, String expectedRelString, String expectedSql) throws Exception { + RelNode relNode = 
sparkToRelConverter.convertSql(sparkSql); + assertEquals(expectedRelString, RelOptUtil.toString(relNode)); + + RelToTrinoConverter relToTrinoConverter = new RelToTrinoConverter(new HiveMscAdapter(Hive.get(conf).getMSC())); + // Convert rel node back to Sql + String expandedSql = relToTrinoConverter.convert(relNode); + assertEquals(expectedSql, expandedSql); + } + +} diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java new file mode 100644 index 000000000..f34e1f514 --- /dev/null +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java @@ -0,0 +1,74 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.spark.spark2rel; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.util.UUID; + +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.CommandNeedRetryException; +import org.apache.hadoop.hive.ql.Driver; +import org.apache.hadoop.hive.ql.metadata.Hive; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse; +import org.apache.hadoop.hive.ql.session.SessionState; + +import com.linkedin.coral.common.HiveMscAdapter; + + +public class ToRelTestUtils { + public static final String CORAL_FROM_TRINO_TEST_DIR = "coral.spark.test.dir"; + + private static HiveMscAdapter hiveMetastoreClient; + public static SparkToRelConverter sparkToRelConverter; + + static void run(Driver driver, String sql) { + while (true) { + try { + CommandProcessorResponse result = driver.run(sql); + if (result.getException() != null) { + throw new RuntimeException("Execution failed for: " + sql, result.getException()); + } + } catch (CommandNeedRetryException e) { + continue; + } + break; + } + } + + public static void initializeViews(HiveConf conf) throws HiveException, MetaException, IOException { + String testDir = conf.get(CORAL_FROM_TRINO_TEST_DIR) + UUID.randomUUID(); + System.out.println("Test Workspace: " + testDir); + FileUtils.deleteDirectory(new File(testDir)); + SessionState.start(conf); + Driver driver = new Driver(conf); + hiveMetastoreClient = new HiveMscAdapter(Hive.get(conf).getMSC()); + sparkToRelConverter = new SparkToRelConverter(hiveMetastoreClient); + + // Views and tables used in SparkToRelConverterTest + run(driver, "CREATE DATABASE IF NOT EXISTS default"); + run(driver, "CREATE TABLE IF NOT EXISTS default.foo(show int, a int, b int, x date, y date)"); + run(driver, "CREATE TABLE IF NOT EXISTS default.my_table(x array<int>, y array<array<int>>, z int)"); + run(driver, "CREATE TABLE IF NOT EXISTS default.a(b int, id int, x int)"); + run(driver, "CREATE TABLE IF NOT EXISTS default.b(foobar int, id int, y int)"); + } + + public static HiveConf loadResourceHiveConf() { + InputStream hiveConfStream = ToRelTestUtils.class.getClassLoader().getResourceAsStream("hive.xml"); + HiveConf hiveConf = new HiveConf(); + hiveConf.set(CORAL_FROM_TRINO_TEST_DIR, + System.getProperty("java.io.tmpdir") + "/coral/trino/" + UUID.randomUUID().toString()); + hiveConf.addResource(hiveConfStream); + hiveConf.set("mapreduce.framework.name", "local-trino"); + hiveConf.set("_hive.hdfs.session.path",
"/tmp/coral/trino"); + hiveConf.set("_hive.local.session.path", "/tmp/coral/trino"); + return hiveConf; + } +}