From a77bf29eabb4a060a11a8a962c150db2a9c8d191 Mon Sep 17 00:00:00 2001
From: Yuya Ebihara
Date: Fri, 13 Jan 2023 12:11:06 +0900
Subject: [PATCH] Support translating Spark SQL to Calcite RelNode

---
 coral-spark/build.gradle | 5 +
 .../coral/spark/AddExplicitAlias.java | 9 +-
 .../com/linkedin/coral/spark/CoralSpark.java | 2 +-
 .../CoralSqlNodeToSparkSqlNodeConverter.java | 2 +-
 .../spark/CoralToSparkSqlCallConverter.java | 2 +-
 .../spark/IRRelToSparkRelTransformer.java | 2 +-
 .../coral/spark/SparkSqlRewriter.java | 2 +-
 .../coral/spark/containers/SparkRelInfo.java | 2 +-
 .../coral/spark/containers/SparkUDFInfo.java | 2 +-
 .../coral/spark/dialect/SparkSqlDialect.java | 2 +-
 .../exceptions/UnsupportedUDFException.java | 2 +-
 .../coral/spark/functions/SqlLateralJoin.java | 2 +-
 .../functions/SqlLateralViewAsOperator.java | 2 +-
 .../spark/spark2rel/OperatorTransformer.java | 354 ++++++++++++++++
 .../Spark2CoralOperatorConverter.java | 35 ++
 .../Spark2CoralOperatorTransformerMap.java | 34 ++
 ...park2CoralOperatorTransformerMapUtils.java | 81 ++++
 .../spark/spark2rel/SparkSqlConformance.java | 20 +
 .../spark2rel/SparkSqlToRelConverter.java | 113 ++++++
 .../spark/spark2rel/SparkToRelConverter.java | 98 +++++
 .../spark/spark2rel/SparkViewExpander.java | 51 +++
 .../parsetree/SparkParserDriver.java | 28 ++
 .../parsetree/SparkSqlAstBuilder.java | 20 +
 .../parsetree/SparkSqlAstVisitor.java | 381 ++++++++++++++++++
 .../parsetree/UnhandledASTNodeException.java | 23 ++
 .../ExtractUnionFunctionTransformer.java | 2 +-
 .../FallBackToLinkedInHiveUDFTransformer.java | 2 +-
 .../FuzzyUnionGenericProjectTransformer.java | 2 +-
 .../transformers/TransportUDFTransformer.java | 2 +-
 .../spark/TransportUDFTransformerTest.java | 2 +-
 .../spark/TransportUDFTransformerTest.java | 2 +-
 .../coral/hive/hive2rel/CoralTestUDF.java | 2 +-
 .../coral/hive/hive2rel/CoralTestUDF2.java | 2 +-
 .../coral/hive/hive2rel/CoralTestUDTF.java | 2 +-
 .../hive/hive2rel/CoralTestUdfSquare.java | 2 +-
 .../hive2rel/CoralTestUnsupportedUDF.java | 2 +-
 .../linkedin/coral/spark/CoralSparkTest.java | 2 +-
 .../coral/spark/FuzzyUnionViewTest.java | 2 +-
 .../com/linkedin/coral/spark/TestUtils.java | 2 +-
 .../spark2rel/SparkToRelConverterTest.java | 174 ++++
 .../coral/spark/spark2rel/ToRelTestUtils.java | 74 ++++
 41 files changed, 1523 insertions(+), 27 deletions(-)
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java
 create mode 100644 coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java
 create mode 100644 coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java
 create mode 100644 coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java

diff --git a/coral-spark/build.gradle b/coral-spark/build.gradle
index f3fbf025c..6c1e27a0f 100644
--- a/coral-spark/build.gradle
+++ b/coral-spark/build.gradle
@@ -3,10 +3,15 @@ apply from: "spark_itest.gradle"
 dependencies {
   compile project(':coral-hive')
   compile project(':coral-schema')
+  compile('org.apache.spark:spark-sql_2.13:3.2.0') {
+    exclude group: 'com.fasterxml.jackson', module: 'jackson-bom'
+    exclude group: 'org.apache.avro', module: 'avro-mapred'
+  }
 
   compileOnly deps.'spark'.'sql'
+  testCompile project(':coral-trino')
 
   testCompile(deps.'hive'.'hive-exec-core') {
     exclude group: 'org.apache.avro', module: 'avro-tools'
     // These exclusions are required to prevent duplicate classes since we include
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java b/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java
index 97807a1ef..6545966a0 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/AddExplicitAlias.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022 LinkedIn Corporation. All rights reserved.
+ * Copyright 2022-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
@@ -11,7 +11,12 @@
 
 import com.google.common.base.Preconditions;
 
-import org.apache.calcite.sql.*;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlNodeList;
+import org.apache.calcite.sql.SqlSelect;
 import org.apache.calcite.sql.fun.SqlStdOperatorTable;
 import org.apache.calcite.sql.parser.SqlParserPos;
 import org.apache.calcite.sql.util.SqlShuttle;
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
index 8ffba3f71..164b0f041 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSpark.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2018-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2018-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
  */
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java
index af3e4a0f6..2982f8f9e 100644
--- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralSqlNodeToSparkSqlNodeConverter.java
@@ -1,5 +1,5 @@
 /**
- * Copyright 2022-2023 LinkedIn Corporation. All rights reserved.
+ * Copyright 2022-2024 LinkedIn Corporation. All rights reserved.
  * Licensed under the BSD-2 Clause license.
  * See LICENSE in the project root for license information.
*/ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java index c8a86d35c..bc321338d 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/CoralToSparkSqlCallConverter.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java index 97f2934a6..cc3a4fd3a 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/IRRelToSparkRelTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java index 5cfa3fb6a..b50b73ddb 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/SparkSqlRewriter.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java index 7e20bed7a..0b2a02def 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkRelInfo.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java index 1d229daa3..910a510a7 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/containers/SparkUDFInfo.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java b/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java index faf330749..5a6d71b5e 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/dialect/SparkSqlDialect.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java b/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java index 4a33e8dba..23b861305 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/exceptions/UnsupportedUDFException.java @@ -1,5 +1,5 @@ /** - * Copyright 2019-2021 LinkedIn Corporation. All rights reserved. + * Copyright 2019-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java index ec6f128cc..e8b7b8913 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralJoin.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java index d0c08ecdb..edc4c30fc 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/functions/SqlLateralViewAsOperator.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java new file mode 100644 index 000000000..88e48824d --- /dev/null +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/OperatorTransformer.java @@ -0,0 +1,354 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. 
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Pattern;
+import java.util.stream.Collectors;
+
+import javax.annotation.Nonnull;
+import javax.annotation.Nullable;
+
+import com.google.gson.JsonArray;
+import com.google.gson.JsonElement;
+import com.google.gson.JsonObject;
+import com.google.gson.JsonParser;
+import com.google.gson.JsonPrimitive;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.parser.SqlParserPos;
+
+import static com.linkedin.coral.common.calcite.CalciteUtil.createCall;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createLiteralBoolean;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createLiteralNumber;
+import static com.linkedin.coral.common.calcite.CalciteUtil.createStringLiteral;
+
+
+/**
+ * Transforms an operator call from one SQL language to another SQL language at the SqlNode layer.
+ *
+ * Suppose f1(a1, a2, ..., an) in the first language can be computed by
+ * f2(b1, b2, ..., bm) in the second language as follows:
+ *   (b1, b2, ..., bm) = g(a1, a2, ..., an)
+ *   f1(a1, a2, ..., an) = h(f2(g(a1, a2, ..., an)))
+ *
+ * We need to define two transformation functions:
+ * - A vector function g for transforming all operands
+ * - A function h for transforming the result
+ *
+ * This class represents g and h as expressions in JSON format as follows:
+ * - Operators: +, -, *, /, ^, and %
+ * - Operands: source operands and literal values
+ *
+ * There may also be situations where a function in one language maps to more than one function in the other
+ * language, depending on the set of input parameters.
+ * We define a set of matching functions to determine which function name is used.
+ * Currently, there is no use case more complicated than matching a parameter string against a static regex.
+ *
+ * Example 1:
+ * In the input IR, TRUNCATE(aDouble, numDigitAfterDot) truncates aDouble by removing
+ * any digit beyond position numDigitAfterDot after the dot, e.g. truncate(11.45, 0) = 11,
+ * truncate(11.45, 1) = 11.4.
+ *
+ * In the target IR, TRUNCATE(aDouble) only takes one argument and removes all digits after the dot,
+ * e.g. truncate(11.45) = 11.
+ *
+ * The transformation of TRUNCATE from one IR to another is represented as follows:
+ * 1. Target IR name: TRUNCATE
+ *
+ * 2. Operand transformers:
+ *    b1 = g(a1, a2) = a1 * 10 ^ a2, with JSON format:
+ *    [
+ *      { "op":"*",
+ *        "operands":[
+ *          {"input":1},   // input 0 is reserved for the result transformer; source inputs start from 1
+ *          { "op":"^",
+ *            "operands":[
+ *              {"value":10},
+ *              {"input":2}]}]}]
+ *
+ * 3. Result transformer:
+ *    h(result) = result / 10 ^ a2, with JSON format:
+ *    { "op":"/",
+ *      "operands":[
+ *        {"input":0},   // input 0 is the result of the target function call
+ *        { "op":"^",
+ *          "operands":[
+ *            {"value":10},
+ *            {"input":2}]}]}
+ *
+ * 4. Operator transformers:
+ *    none
+ *
+ * Example 2:
+ * In the input IR, there exists a Hive-derived function to decode binary data given a format, DECODE(binary, scheme).
+ * In the target IR, there is no generic decoding function that takes a decoding scheme.
+ * Instead, there exist specific decoding functions that are first-class functions, like FROM_UTF8(binary).
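+ * For illustration (the scheme values here are hypothetical): DECODE(payload, 'utf-8') must become
+ * FROM_UTF8(payload), while DECODE(payload, 'hex') would have to map to a different target function.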
+ * Consequently, we need to know the operands of the call in order to determine the corresponding target function.
+ *
+ * The transformation of DECODE from one IR to another is represented as follows:
+ * 1. Target IR name: There is no function name determined at compile time.
+ *    null
+ *
+ * 2. Operand transformers: We want to retain operand 1 and drop operand 2:
+ *    [{"input":1}]
+ *
+ * 3. Result transformer: No transformation is performed on the output.
+ *    null
+ *
+ * 4. Operator transformers: Check that the second parameter (scheme) matches 'utf-8' with any casing, using a Java regex:
+ *    [ {
+ *        "regex" : "^.*(?i)(utf-8).*$",
+ *        "input" : 2,
+ *        "name" : "from_utf8"
+ *      }
+ *    ]
+ */
+class OperatorTransformer {
+  private static final Map<String, SqlOperator> OP_MAP = new HashMap<>();
+
+  // Operators allowed in the transformation
+  static {
+    OP_MAP.put("+", SqlStdOperatorTable.PLUS);
+    OP_MAP.put("-", SqlStdOperatorTable.MINUS);
+    OP_MAP.put("*", SqlStdOperatorTable.MULTIPLY);
+    OP_MAP.put("/", SqlStdOperatorTable.DIVIDE);
+    OP_MAP.put("^", SqlStdOperatorTable.POWER);
+    OP_MAP.put("%", SqlStdOperatorTable.MOD);
+  }
+
+  public static final String OPERATOR = "op";
+  public static final String OPERANDS = "operands";
+  /**
+   * For an input node:
+   * - input equal to 0 refers to the result
+   * - input greater than 0 refers to the index of a source operand (starting from 1)
+   */
+  public static final String INPUT = "input";
+  public static final String VALUE = "value";
+  public static final String REGEX = "regex";
+  public static final String NAME = "name";
+
+  public final String fromOperatorName;
+  public final SqlOperator targetOperator;
+  public final List<JsonObject> operandTransformers;
+  public final JsonObject resultTransformer;
+  public final List<JsonObject> operatorTransformers;
+
+  private OperatorTransformer(String fromOperatorName, SqlOperator targetOperator,
+      List<JsonObject> operandTransformers, JsonObject resultTransformer, List<JsonObject> operatorTransformers) {
+    this.fromOperatorName = fromOperatorName;
+    this.targetOperator = targetOperator;
+    this.operandTransformers = operandTransformers;
+    this.resultTransformer = resultTransformer;
+    this.operatorTransformers = operatorTransformers;
+  }
+
+  /**
+   * Creates a new transformer.
+   *
+   * @param fromOperatorName Name of the function associated with this operator in the input IR
+   * @param targetOperator Operator in the target language
+   * @param operandTransformers JSON string representing the operand transformations,
+   *        null for identity transformations
+   * @param resultTransformer JSON string representing the result transformation,
+   *        null for identity transformation
+   * @param operatorTransformers JSON string representing an array of transformers that can vary the name of the target
+   *        operator based on runtime parameter values.
+   *        In the order of the JSON array, the first transformer that matches has its given operator
+   *        name selected as the target operator name.
+   *        Operands are indexed beginning at index 1.
+   *        An operatorTransformer has the following serialized JSON string format:
+   *        "[
+   *          {
+   *            \"name\" : \"{Name of function if this matches}\",
+   *            \"input\" : {Index of the parameter, starting at index 1, that is evaluated},
+   *            \"regex\" : \"{Java regex string matching the parameter at the given input}\"
+   *          },
+   *          ...
+   *        ]"
+   *        For example, a transformer for an operator named "foo" when parameter 2 matches exactly
+   *        "bar" is specified as:
+   *        "[
+   *          {
+   *            \"name\" : \"foo\",
+   *            \"input\" : 2,
+   *            \"regex\" : \"'bar'\"
+   *          }
+   *        ]"
+   *        NOTE: A string literal is represented exactly as ['STRING_LITERAL'] with the single
+   *        quotation marks INCLUDED.
+   *        As seen in the example above, the single quotation marks are also present in the
+   *        regex matcher.
+   *
+   * @return {@link OperatorTransformer} object
+   */
+  public static OperatorTransformer of(@Nonnull String fromOperatorName, @Nonnull SqlOperator targetOperator,
+      @Nullable String operandTransformers, @Nullable String resultTransformer, @Nullable String operatorTransformers) {
+    List<JsonObject> operands = null;
+    JsonObject result = null;
+    List<JsonObject> operators = null;
+    if (operandTransformers != null) {
+      operands = parseJsonObjectsFromString(operandTransformers);
+    }
+    if (resultTransformer != null) {
+      result = new JsonParser().parse(resultTransformer).getAsJsonObject();
+    }
+    if (operatorTransformers != null) {
+      operators = parseJsonObjectsFromString(operatorTransformers);
+    }
+    return new OperatorTransformer(fromOperatorName, targetOperator, operands, result, operators);
+  }
+
+  /**
+   * Transforms a call to the source operator into an equivalent call to the target operator.
+   *
+   * @param sourceOperands Source operands
+   * @return An expression calling the target operator that is equivalent to the source operator call
+   */
+  public SqlNode transformCall(List<SqlNode> sourceOperands) {
+    final SqlOperator newTargetOperator = transformTargetOperator(targetOperator, sourceOperands);
+    if (newTargetOperator == null || newTargetOperator.getName().isEmpty()) {
+      String operands = sourceOperands.stream().map(SqlNode::toString).collect(Collectors.joining(","));
+      throw new IllegalArgumentException(
+          String.format("An equivalent operator in the target IR was not found for the function call: %s(%s)",
+              fromOperatorName, operands));
+    }
+    final List<SqlNode> newOperands = transformOperands(sourceOperands);
+    final SqlCall newCall = createCall(newTargetOperator, newOperands, SqlParserPos.ZERO);
+    return transformResult(newCall, sourceOperands);
+  }
+
+  private List<SqlNode> transformOperands(List<SqlNode> sourceOperands) {
+    if (operandTransformers == null) {
+      return sourceOperands;
+    }
+    final List<SqlNode> sources = new ArrayList<>();
+    // Add a dummy expression for input 0
+    sources.add(null);
+    sources.addAll(sourceOperands);
+    final List<SqlNode> results = new ArrayList<>();
+    for (JsonObject operandTransformer : operandTransformers) {
+      results.add(transformExpression(operandTransformer, sources));
+    }
+    return results;
+  }
+
+  private SqlNode transformResult(SqlNode result, List<SqlNode> sourceOperands) {
+    if (resultTransformer == null) {
+      return result;
+    }
+    final List<SqlNode> sources = new ArrayList<>();
+    // Result will be input 0
+    sources.add(result);
+    sources.addAll(sourceOperands);
+    return transformExpression(resultTransformer, sources);
+  }
+
+  /**
+   * Applies a single transformer expression to the given source operands.
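+   * For illustration, evaluating the (hypothetical) TRUNCATE operand transformer from the class-level javadoc,
+   *   {"op":"*", "operands":[{"input":1}, {"op":"^", "operands":[{"value":10}, {"input":2}]}]},
+   * against source operands (aDouble, numDigitAfterDot) resolves "input" nodes to source operands,
+   * "value" nodes to literals, and "op" nodes to calls from OP_MAP, yielding the SqlNode tree
+   * aDouble * POWER(10, numDigitAfterDot).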
+   */
+  private SqlNode transformExpression(JsonObject transformer, List<SqlNode> sourceOperands) {
+    if (transformer.get(OPERATOR) != null) {
+      final List<SqlNode> inputOperands = new ArrayList<>();
+      for (JsonElement inputOperand : transformer.getAsJsonArray(OPERANDS)) {
+        if (inputOperand.isJsonObject()) {
+          inputOperands.add(transformExpression(inputOperand.getAsJsonObject(), sourceOperands));
+        }
+      }
+      final String operatorName = transformer.get(OPERATOR).getAsString();
+      final SqlOperator op = OP_MAP.get(operatorName);
+      if (op == null) {
+        throw new UnsupportedOperationException("Operator " + operatorName + " is not supported in transformation");
+      }
+      return createCall(op, inputOperands, SqlParserPos.ZERO);
+    }
+    if (transformer.get(INPUT) != null) {
+      int index = transformer.get(INPUT).getAsInt();
+      if (index < 0 || index >= sourceOperands.size() || sourceOperands.get(index) == null) {
+        throw new IllegalArgumentException(
+            "Invalid input value: " + index + ". Number of source operands: " + sourceOperands.size());
+      }
+      return sourceOperands.get(index);
+    }
+    final JsonElement value = transformer.get(VALUE);
+    if (value == null) {
+      throw new IllegalArgumentException("JSON node for transformation should be either op, input, or value");
+    }
+    if (!value.isJsonPrimitive()) {
+      throw new IllegalArgumentException("Value should be of primitive type: " + value);
+    }
+
+    final JsonPrimitive primitive = value.getAsJsonPrimitive();
+    if (primitive.isString()) {
+      return createStringLiteral(primitive.getAsString(), SqlParserPos.ZERO);
+    }
+    if (primitive.isBoolean()) {
+      return createLiteralBoolean(primitive.getAsBoolean(), SqlParserPos.ZERO);
+    }
+    if (primitive.isNumber()) {
+      return createLiteralNumber(value.getAsBigDecimal().longValue(), SqlParserPos.ZERO);
+    }
+
+    throw new UnsupportedOperationException("Invalid JSON literal value: " + primitive);
+  }
+
+  /**
+   * Returns a SqlOperator whose function name is chosen based on the values of the source operands.
+   */
+  private SqlOperator transformTargetOperator(SqlOperator operator, List<SqlNode> sourceOperands) {
+    if (operatorTransformers == null) {
+      return operator;
+    }
+
+    for (JsonObject operatorTransformer : operatorTransformers) {
+      if (!operatorTransformer.has(REGEX) || !operatorTransformer.has(INPUT) || !operatorTransformer.has(NAME)) {
+        throw new IllegalArgumentException(
+            "JSON node for a target operator transformer must have a regex, an input, and a name");
+      }
+      // We use the same convention as operand and result transformers.
+      // Therefore, source index values start at index 1 instead of index 0.
+      // Acceptable index values are [1, size].
+      int index = operatorTransformer.get(INPUT).getAsInt() - 1;
+      if (index < 0 || index >= sourceOperands.size()) {
+        throw new IllegalArgumentException(
+            String.format("Index is not within the acceptable range [%d, %d]", 1, sourceOperands.size()));
+      }
+      String functionName = operatorTransformer.get(NAME).getAsString();
+      if (functionName.isEmpty()) {
+        throw new IllegalArgumentException("JSON node for transformation must have a non-empty name");
+      }
+      String matcher = operatorTransformer.get(REGEX).getAsString();
+
+      if (Pattern.matches(matcher, sourceOperands.get(index).toString())) {
+        return Spark2CoralOperatorTransformerMapUtils.createOperator(functionName, operator.getReturnTypeInference(),
+            null);
+      }
+    }
+    return operator;
+  }
+
+  /**
+   * Creates a list of JsonObjects from a string input.
+   * The input string must be a serialized JSON array.
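+   * For example (illustrative input), "[{\"input\":1}, {\"value\":10}]" yields a list of two JsonObjects;
+   * Gson's getAsJsonArray throws an IllegalStateException if the string is not a JSON array.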
+   */
+  private static List<JsonObject> parseJsonObjectsFromString(String s) {
+    List<JsonObject> objects = new ArrayList<>();
+    JsonArray transformerArray = new JsonParser().parse(s).getAsJsonArray();
+    for (JsonElement object : transformerArray) {
+      objects.add(object.getAsJsonObject());
+    }
+    return objects;
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java
new file mode 100644
index 000000000..9510df33b
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorConverter.java
@@ -0,0 +1,35 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.Locale;
+
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.util.SqlShuttle;
+
+
+/**
+ * Rewrites the SqlNode tree to replace Spark SQL operators with Coral IR operators to obtain a Coral-compatible plan.
+ */
+public class Spark2CoralOperatorConverter extends SqlShuttle {
+  public Spark2CoralOperatorConverter() {
+  }
+
+  @Override
+  public SqlNode visit(final SqlCall call) {
+    final String operatorName = call.getOperator().getName();
+
+    final OperatorTransformer transformer = Spark2CoralOperatorTransformerMap
+        .getOperatorTransformer(operatorName.toLowerCase(Locale.ROOT), call.operandCount());
+
+    if (transformer == null) {
+      return super.visit(call);
+    }
+
+    return super.visit((SqlCall) transformer.transformCall(call.getOperandList()));
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java
new file mode 100644
index 000000000..68b3ee836
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMap.java
@@ -0,0 +1,34 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static com.linkedin.coral.spark.spark2rel.Spark2CoralOperatorTransformerMapUtils.getKey;
+
+
+public class Spark2CoralOperatorTransformerMap {
+  private Spark2CoralOperatorTransformerMap() {
+  }
+
+  public static final Map<String, OperatorTransformer> TRANSFORMER_MAP = new HashMap<>();
+
+  static {
+    // TODO: keep adding Spark-specific functions as needed
+  }
+
+  /**
+   * Gets the {@link OperatorTransformer} for a given Spark SQL operator.
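+   * Lookup is keyed by getKey(sparkOpName, numOperands), so the same function name may map to different
+   * transformers per arity. For example (hypothetical registration), a two-operand Spark function "foo"
+   * registered via Spark2CoralOperatorTransformerMapUtils.createTransformerMapEntry would be fetched
+   * here with getOperatorTransformer("foo", 2).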
+   *
+   * @param sparkOpName Name of the Spark SQL operator
+   * @param numOperands Number of operands
+   * @return {@link OperatorTransformer} object
+   */
+  public static OperatorTransformer getOperatorTransformer(String sparkOpName, int numOperands) {
+    return TRANSFORMER_MAP.get(getKey(sparkOpName, numOperands));
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java
new file mode 100644
index 000000000..d9fa188e0
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/Spark2CoralOperatorTransformerMapUtils.java
@@ -0,0 +1,81 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.Map;
+
+import org.apache.calcite.sql.SqlIdentifier;
+import org.apache.calcite.sql.SqlOperator;
+import org.apache.calcite.sql.parser.SqlParserPos;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
+import org.apache.calcite.sql.type.SqlReturnTypeInference;
+import org.apache.calcite.sql.validate.SqlUserDefinedFunction;
+
+
+public class Spark2CoralOperatorTransformerMapUtils {
+
+  private Spark2CoralOperatorTransformerMapUtils() {
+  }
+
+  /**
+   * Creates a mapping from a Spark SqlOperator name to an {@link OperatorTransformer}.
+   *
+   * @param transformerMap Map to store the result
+   * @param sparkOp Spark SQL operator
+   * @param numOperands Number of operands
+   * @param calciteOperatorName Name of the Calcite operator
+   */
+  static void createTransformerMapEntry(Map<String, OperatorTransformer> transformerMap, SqlOperator sparkOp,
+      int numOperands, String calciteOperatorName) {
+    createTransformerMapEntry(transformerMap, sparkOp, numOperands, calciteOperatorName, null, null);
+  }
+
+  /**
+   * Creates a mapping from a Spark SqlOperator name to a Calcite operator with the given Calcite operator name,
+   * operand transformer, and result transformer.
+   * To construct the Calcite SqlOperator from the Calcite operator name, this method reuses the return type
+   * inference from sparkOp, assuming equivalence.
+   *
+   * @param transformerMap Map to store the result
+   * @param sparkOp Spark SQL operator
+   * @param numOperands Number of operands
+   * @param calciteOperatorName Name of the Calcite operator
+   * @param operandTransformer Operand transformers, null for identity transformation
+   * @param resultTransformer Result transformer, null for identity transformation
+   */
+  static void createTransformerMapEntry(Map<String, OperatorTransformer> transformerMap, SqlOperator sparkOp,
+      int numOperands, String calciteOperatorName, String operandTransformer, String resultTransformer) {
+    createTransformerMapEntry(transformerMap, sparkOp, numOperands,
+        createOperator(calciteOperatorName, sparkOp.getReturnTypeInference(), sparkOp.getOperandTypeChecker()),
+        operandTransformer, resultTransformer);
+  }
+
+  /**
+   * Creates a mapping from a Spark SqlOperator name to a Calcite UDF with the given Calcite SqlOperator,
+   * operand transformer, and result transformer.
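+   * For example (hypothetical mapping), registering a two-operand Spark function "foo" that is computed by a
+   * Calcite operator barOp with its operands swapped could look like:
+   *   createTransformerMapEntry(TRANSFORMER_MAP, fooOp, 2, barOp, "[{\"input\":2}, {\"input\":1}]", null);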
+   *
+   * @param transformerMap Map to store the result
+   * @param sparkOp Spark SQL operator
+   * @param numOperands Number of operands
+   * @param calciteSqlOperator The Calcite SqlOperator that is used as the target operator in the map
+   * @param operandTransformer Operand transformers, null for identity transformation
+   * @param resultTransformer Result transformer, null for identity transformation
+   */
+  static void createTransformerMapEntry(Map<String, OperatorTransformer> transformerMap, SqlOperator sparkOp,
+      int numOperands, SqlOperator calciteSqlOperator, String operandTransformer, String resultTransformer) {
+
+    transformerMap.put(getKey(sparkOp.getName(), numOperands),
+        OperatorTransformer.of(sparkOp.getName(), calciteSqlOperator, operandTransformer, resultTransformer, null));
+  }
+
+  static SqlOperator createOperator(String functionName, SqlReturnTypeInference returnTypeInference,
+      SqlOperandTypeChecker operandTypeChecker) {
+    return new SqlUserDefinedFunction(new SqlIdentifier(functionName, SqlParserPos.ZERO), returnTypeInference, null,
+        operandTypeChecker, null, null);
+  }
+
+  static String getKey(String sparkOpName, int numOperands) {
+    return sparkOpName + "_" + numOperands;
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java
new file mode 100644
index 000000000..d18605ea9
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlConformance.java
@@ -0,0 +1,20 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import org.apache.calcite.sql.validate.SqlConformance;
+import org.apache.calcite.sql.validate.SqlConformanceEnum;
+import org.apache.calcite.sql.validate.SqlDelegatingConformance;
+
+
+public class SparkSqlConformance extends SqlDelegatingConformance {
+
+  public static final SqlConformance SPARK_SQL = new SparkSqlConformance();
+
+  private SparkSqlConformance() {
+    super(SqlConformanceEnum.PRAGMATIC_2003);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java
new file mode 100644
index 000000000..2e7da2c87
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkSqlToRelConverter.java
@@ -0,0 +1,113 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.calcite.linq4j.Ord;
+import org.apache.calcite.plan.Convention;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.plan.RelOptUtil;
+import org.apache.calcite.prepare.Prepare;
+import org.apache.calcite.rel.RelCollation;
+import org.apache.calcite.rel.RelCollations;
+import org.apache.calcite.rel.RelNode;
+import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.core.Uncollect;
+import org.apache.calcite.rel.logical.LogicalValues;
+import org.apache.calcite.rel.metadata.JaninoRelMetadataProvider;
+import org.apache.calcite.rel.metadata.RelMetadataQuery;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlCall;
+import org.apache.calcite.sql.SqlExplainFormat;
+import org.apache.calcite.sql.SqlExplainLevel;
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlUnnestOperator;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql2rel.SqlRexConvertletTable;
+import org.apache.calcite.sql2rel.SqlToRelConverter;
+
+import com.linkedin.coral.hive.hive2rel.functions.HiveExplodeOperator;
+import com.linkedin.coral.hive.hive2rel.rel.HiveUncollect;
+
+
+/**
+ * Class to convert Spark SQL to a Calcite RelNode. This class
+ * specializes the functionality provided by {@link SqlToRelConverter}.
+ */
+class SparkSqlToRelConverter extends SqlToRelConverter {
+
+  SparkSqlToRelConverter(RelOptTable.ViewExpander viewExpander, SqlValidator validator,
+      Prepare.CatalogReader catalogReader, RelOptCluster cluster, SqlRexConvertletTable convertletTable,
+      Config config) {
+    super(viewExpander, validator, catalogReader, cluster, convertletTable, config);
+  }
+
+  // This differs from the base class in two ways:
+  // 1. It does not validate that the row type of the converted rel matches that of the validated node,
+  //    because Hive is lax in enforcing view schemas.
+  // 2. It skips calling some methods because (1) those are private, and (2) they are not required for our use case.
+  public RelRoot convertQuery(SqlNode query, final boolean needsValidation, final boolean top) {
+    if (needsValidation) {
+      query = validator.validate(query);
+    }
+
+    RelMetadataQuery.THREAD_PROVIDERS.set(JaninoRelMetadataProvider.of(cluster.getMetadataProvider()));
+    RelNode result = convertQueryRecursive(query, top, null).rel;
+    RelCollation collation = RelCollations.EMPTY;
+
+    if (SQL2REL_LOGGER.isDebugEnabled()) {
+      SQL2REL_LOGGER.debug(RelOptUtil.dumpPlan("Plan after converting SqlNode to RelNode", result,
+          SqlExplainFormat.TEXT, SqlExplainLevel.EXPPLAN_ATTRIBUTES));
+    }
+
+    final RelDataType validatedRowType = validator.getValidatedNodeType(query);
+    return RelRoot.of(result, validatedRowType, query.getKind()).withCollation(collation);
+  }
+
+  @Override
+  protected void convertFrom(Blackboard bb, SqlNode from) {
+    if (from == null) {
+      super.convertFrom(bb, from);
+      return;
+    }
+    switch (from.getKind()) {
+      case UNNEST:
+        convertUnnestFrom(bb, from);
+        break;
+      default:
+        super.convertFrom(bb, from);
+        break;
+    }
+  }
+
+  private void convertUnnestFrom(Blackboard bb, SqlNode from) {
+    final SqlCall call = (SqlCall) from;
+    final List<SqlNode> nodes = call.getOperandList();
+    final SqlUnnestOperator operator = (SqlUnnestOperator) call.getOperator();
+    // FIXME: the base class calls 'replaceSubqueries' for the operands here, but that is a private
+    // method. This is not an issue for our use cases with Hive, but we may need to handle it in the future.
+    final List<RexNode> exprs = new ArrayList<>();
+    final List<String> fieldNames = new ArrayList<>();
+    for (Ord<SqlNode> node : Ord.zip(nodes)) {
+      exprs.add(bb.convertExpression(node.e));
+      // In Hive, "LATERAL VIEW EXPLODE(arr) t" is equivalent to "LATERAL VIEW EXPLODE(arr) t AS col".
+      // Use the default column name "col" if not specified.
+      fieldNames.add(node.e.getKind() == SqlKind.AS ? validator.deriveAlias(node.e, node.i)
+          : HiveExplodeOperator.ARRAY_ELEMENT_COLUMN_NAME);
+    }
+    final RelNode input = RelOptUtil.createProject((null != bb.root) ? bb.root : LogicalValues.createOneRow(cluster),
+        exprs, fieldNames, true);
+    Uncollect uncollect =
+        new HiveUncollect(cluster, cluster.traitSetOf(Convention.NONE), input, operator.withOrdinality);
+    bb.setRoot(uncollect, true);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java
new file mode 100644
index 000000000..f0a49cfc9
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverter.java
@@ -0,0 +1,98 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+
+import org.apache.calcite.adapter.java.JavaTypeFactory;
+import org.apache.calcite.plan.RelOptCluster;
+import org.apache.calcite.plan.volcano.VolcanoPlanner;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.sql.SqlOperatorTable;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.util.ChainedSqlOperatorTable;
+import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql2rel.SqlRexConvertletTable;
+import org.apache.calcite.sql2rel.SqlToRelConverter;
+import org.apache.hadoop.hive.metastore.api.Table;
+
+import com.linkedin.coral.common.HiveMetastoreClient;
+import com.linkedin.coral.common.HiveRelBuilder;
+import com.linkedin.coral.common.ToRelConverter;
+import com.linkedin.coral.hive.hive2rel.DaliOperatorTable;
+import com.linkedin.coral.hive.hive2rel.HiveConvertletTable;
+import com.linkedin.coral.hive.hive2rel.HiveSqlValidator;
+import com.linkedin.coral.hive.hive2rel.functions.HiveFunctionResolver;
+import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry;
+import com.linkedin.coral.spark.spark2rel.parsetree.SparkParserDriver;
+
+import static com.linkedin.coral.spark.spark2rel.SparkSqlConformance.SPARK_SQL;
+
+
+/*
+ * This class is the public entry point: a thin wrapper around SparkSqlToRelConverter. Using
+ * SparkSqlToRelConverter directly would expose public methods from SqlToRelConverter. Our use of
+ * SqlToRelConverter is likely to change in the future if we want more control over the
+ * conversion process. This class abstracts that away.
+ */
+public class SparkToRelConverter extends ToRelConverter {
+  private final HiveFunctionResolver functionResolver =
+      new HiveFunctionResolver(new StaticHiveFunctionRegistry(), new ConcurrentHashMap<>());
+  // The validator must be reused
+  private final SqlValidator sqlValidator = new HiveSqlValidator(getOperatorTable(), getCalciteCatalogReader(),
+      ((JavaTypeFactory) getRelBuilder().getTypeFactory()), SPARK_SQL);
+
+  public SparkToRelConverter(HiveMetastoreClient hiveMetastoreClient) {
+    super(hiveMetastoreClient);
+  }
+
+  public SparkToRelConverter(Map<String, Map<String, List<String>>> localMetaStore) {
+    super(localMetaStore);
+  }
+
+  @Override
+  protected SqlRexConvertletTable getConvertletTable() {
+    return new HiveConvertletTable();
+  }
+
+  @Override
+  protected SqlValidator getSqlValidator() {
+    return sqlValidator;
+  }
+
+  @Override
+  protected SqlOperatorTable getOperatorTable() {
+    return ChainedSqlOperatorTable.of(SqlStdOperatorTable.instance(), new DaliOperatorTable(functionResolver));
+  }
+
+  @Override
+  protected SqlToRelConverter getSqlToRelConverter() {
+    return new SparkSqlToRelConverter(new SparkViewExpander(this), getSqlValidator(), getCalciteCatalogReader(),
+        RelOptCluster.create(new VolcanoPlanner(), getRelBuilder().getRexBuilder()), getConvertletTable(),
+        SqlToRelConverter.configBuilder().withRelBuilderFactory(HiveRelBuilder.LOGICAL_BUILDER).build());
+  }
+
+  @Override
+  protected SqlNode toSqlNode(String sql, Table sparkView) {
+    String trimmedSql = trimParenthesis(sql.toUpperCase());
+    SqlNode parsedSqlNode = SparkParserDriver.parse(trimmedSql);
+    return parsedSqlNode.accept(new Spark2CoralOperatorConverter());
+  }
+
+  private static String trimParenthesis(String value) {
+    String str = value.trim();
+    if (str.startsWith("(") && str.endsWith(")")) {
+      return trimParenthesis(str.substring(1, str.length() - 1));
+    }
+    return str;
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java
new file mode 100644
index 000000000..f4958c2d6
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/SparkViewExpander.java
@@ -0,0 +1,51 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel;
+
+import java.util.List;
+
+import javax.annotation.Nonnull;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.calcite.plan.RelOptTable;
+import org.apache.calcite.rel.RelRoot;
+import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.calcite.util.Util;
+
+import com.linkedin.coral.common.FuzzyUnionSqlRewriter;
+
+
+/**
+ * Class that implements the {@link RelOptTable.ViewExpander}
+ * interface to support expansion of Spark views to relational algebra.
+ */
+public class SparkViewExpander implements RelOptTable.ViewExpander {
+
+  private final SparkToRelConverter sparkToRelConverter;
+
+  /**
+   * Instantiates a new Spark view expander.
+   *
+   * @param sparkToRelConverter Spark to Rel converter
+   */
+  public SparkViewExpander(@Nonnull SparkToRelConverter sparkToRelConverter) {
+    this.sparkToRelConverter = sparkToRelConverter;
+  }
+
+  @Override
+  public RelRoot expandView(RelDataType rowType, String queryString, List<String> schemaPath, List<String> viewPath) {
+    Preconditions.checkNotNull(viewPath);
+    Preconditions.checkState(!viewPath.isEmpty());
+
+    String dbName = Util.last(schemaPath);
+    String tableName = viewPath.get(0);
+
+    SqlNode sqlNode = sparkToRelConverter.processView(dbName, tableName)
+        .accept(new FuzzyUnionSqlRewriter(tableName, sparkToRelConverter));
+    return sparkToRelConverter.getSqlToRelConverter().convertQuery(sqlNode, true, true);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java
new file mode 100644
index 000000000..21fa4b026
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkParserDriver.java
@@ -0,0 +1,28 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel.parsetree;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.apache.calcite.sql.SqlNode;
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
+import org.apache.spark.sql.execution.SparkSqlParser;
+
+
+public class SparkParserDriver {
+  private static final SparkSqlParser SPARK_SQL_PARSER = new SparkSqlParser();
+
+  /**
+   * Uses the SparkSqlParser to parse the command and returns the corresponding Calcite SqlNode.
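+   * For example (illustrative query), parse("SELECT A, B FROM DB.TBL WHERE C > 1") returns a SqlSelect
+   * whose select list, FROM identifier, and WHERE predicate are built by {@link SparkSqlAstVisitor}.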
+   *
+   * @param command Spark SQL command
+   * @return the parsed {@link SqlNode}
+   */
+  public static SqlNode parse(String command) {
+    ParserRuleContext context = SPARK_SQL_PARSER.parse(command, new SparkSqlAstBuilder());
+    SparkSqlAstVisitor visitor = new SparkSqlAstVisitor();
+    return visitor.visitSingleStatement((SqlBaseParser.SingleStatementContext) context);
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java
new file mode 100644
index 000000000..c8b64f1e1
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstBuilder.java
@@ -0,0 +1,20 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */
+package com.linkedin.coral.spark.spark2rel.parsetree;
+
+import org.antlr.v4.runtime.ParserRuleContext;
+import org.apache.spark.sql.catalyst.parser.AstBuilder;
+import org.apache.spark.sql.catalyst.parser.SqlBaseParser;
+
+import scala.Function1;
+
+
+public class SparkSqlAstBuilder extends AstBuilder implements Function1<SqlBaseParser, ParserRuleContext> {
+  @Override
+  public ParserRuleContext apply(SqlBaseParser v1) {
+    return v1.singleStatement();
+  }
+}
diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java
new file mode 100644
index 000000000..d9e666be8
--- /dev/null
+++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/SparkSqlAstVisitor.java
@@ -0,0 +1,381 @@
+/**
+ * Copyright 2023-2024 LinkedIn Corporation. All rights reserved.
+ * Licensed under the BSD-2 Clause license.
+ * See LICENSE in the project root for license information.
+ */ +package com.linkedin.coral.spark.spark2rel.parsetree; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +import org.antlr.v4.runtime.ParserRuleContext; +import org.antlr.v4.runtime.tree.ParseTree; +import org.antlr.v4.runtime.tree.RuleNode; +import org.antlr.v4.runtime.tree.TerminalNode; +import org.apache.calcite.sql.SqlDataTypeSpec; +import org.apache.calcite.sql.SqlFunctionCategory; +import org.apache.calcite.sql.SqlIdentifier; +import org.apache.calcite.sql.SqlLiteral; +import org.apache.calcite.sql.SqlNode; +import org.apache.calcite.sql.SqlNodeList; +import org.apache.calcite.sql.SqlOrderBy; +import org.apache.calcite.sql.SqlSelect; +import org.apache.calcite.sql.SqlUnresolvedFunction; +import org.apache.calcite.sql.parser.SqlParserPos; +import org.apache.calcite.sql.type.SqlTypeFactoryImpl; +import org.apache.calcite.sql.type.SqlTypeName; +import org.apache.calcite.sql.type.SqlTypeUtil; +import org.apache.spark.sql.catalyst.parser.SqlBaseBaseVisitor; +import org.apache.spark.sql.catalyst.parser.SqlBaseParser; + +import com.linkedin.coral.common.HiveTypeSystem; +import com.linkedin.coral.common.calcite.CalciteUtil; + +import static com.linkedin.coral.common.calcite.CalciteUtil.createCall; +import static com.linkedin.coral.common.calcite.CalciteUtil.createSqlIdentifier; +import static com.linkedin.coral.common.calcite.CalciteUtil.createStarIdentifier; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.AS; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.BETWEEN; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CAST; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.CONCAT; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.DIVIDE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.DIVIDE_INTEGER; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.EQUALS; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.GREATER_THAN; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.GREATER_THAN_OR_EQUAL; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IN; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_DISTINCT_FROM; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_FALSE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_NULL; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_TRUE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.IS_UNKNOWN; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LESS_THAN_OR_EQUAL; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.LIKE; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MINUS; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.MULTIPLY; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.NOT_EQUALS; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PERCENT_REMAINDER; +import static org.apache.calcite.sql.fun.SqlStdOperatorTable.PLUS; +import static org.apache.calcite.sql.parser.SqlParserPos.ZERO; + + +public class SparkSqlAstVisitor extends SqlBaseBaseVisitor { + private static final SqlTypeFactoryImpl SQL_TYPE_FACTORY = new SqlTypeFactoryImpl(new HiveTypeSystem()); + private static final String UNSUPPORTED_EXCEPTION_MSG = "%s at line %d column %d is 
not supported in the visit."; + + @Override + public SqlNode visitChildren(RuleNode node) { + if (node.getChildCount() == 2 && node.getChild(1) instanceof TerminalNode) { + return node.getChild(0).accept(this); + } + return super.visitChildren(node); + } + + @Override + public SqlNode visitQuery(SqlBaseParser.QueryContext ctx) { + SqlSelect select = (SqlSelect) ((SqlNodeList) visit(ctx.queryTerm())).get(0); + SqlNodeList orderBy = ctx.queryOrganization().order.isEmpty() ? null : new SqlNodeList( + ctx.queryOrganization().order.stream().map(this::visit).collect(Collectors.toList()), getPos(ctx)); + SqlNode limit = ctx.queryOrganization().limit == null ? null : visit(ctx.queryOrganization().limit); + if (orderBy != null || limit != null) { + return new SqlOrderBy(getPos(ctx), select, orderBy, null, limit); + } + return select; + } + + @Override + public SqlNode visitRegularQuerySpecification(SqlBaseParser.RegularQuerySpecificationContext ctx) { + SqlNodeList selectList = visitSelectClause(ctx.selectClause()); + SqlNode from = ctx.fromClause() == null ? null : visitFromClause(ctx.fromClause()); + SqlNode where = ctx.whereClause() == null ? null : visitWhereClause(ctx.whereClause()); + SqlNodeList groupBy = + ctx.aggregationClause() == null ? null : visitGroupByClause(ctx.aggregationClause().groupByClause); + SqlNode having = ctx.havingClause() == null ? null : visitHavingClause(ctx.havingClause()); + return new SqlSelect(getPos(ctx), null, selectList, from, where, groupBy, having, null, null, null, null); + } + + @Override + public SqlNode visitWhereClause(SqlBaseParser.WhereClauseContext ctx) { + return super.visitWhereClause(ctx); + } + + @Override + public SqlNode visitCtes(SqlBaseParser.CtesContext ctx) { + throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG); + } + + @Override + public SqlNode visitJoinRelation(SqlBaseParser.JoinRelationContext ctx) { + throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG); + } + + @Override + public SqlNode visitJoinType(SqlBaseParser.JoinTypeContext ctx) { + throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG); + } + + @Override + public SqlNode visitJoinCriteria(SqlBaseParser.JoinCriteriaContext ctx) { + throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG); + } + + @Override + public SqlNode visitPredicated(SqlBaseParser.PredicatedContext ctx) { + SqlNode expression = visit(ctx.valueExpression()); + if (ctx.predicate() != null) { + return withPredicate(expression, ctx.predicate()); + } + return expression; + } + + private SqlNode withPredicate(SqlNode expression, SqlBaseParser.PredicateContext ctx) { + SqlNode predicate = toPredicate(expression, ctx); + if (ctx.NOT() == null) { + return predicate; + } + return NOT.createCall(getPos(ctx), predicate); + } + + private SqlNode toPredicate(SqlNode expression, SqlBaseParser.PredicateContext ctx) { + SqlParserPos position = getPos(ctx); + int type = ctx.kind.getType(); + switch (type) { + case SqlBaseParser.BETWEEN: + return BETWEEN.createCall(position, expression, visit(ctx.right)); + case SqlBaseParser.IN: + return IN.createCall(position, expression, visit(ctx.right)); + case SqlBaseParser.LIKE: + return LIKE.createCall(position, expression, visit(ctx.right)); + case SqlBaseParser.RLIKE: + throw new UnsupportedOperationException("Unsupported predicate type: RLIKE"); + case SqlBaseParser.NULL: + return IS_NULL.createCall(position, expression); + case SqlBaseParser.TRUE: + return IS_TRUE.createCall(position, expression); + case SqlBaseParser.FALSE: + 
return IS_FALSE.createCall(position, expression); + case SqlBaseParser.UNKNOWN: + return IS_UNKNOWN.createCall(position, expression); + case SqlBaseParser.DISTINCT: + return IS_DISTINCT_FROM.createCall(position, expression, visit(ctx.right)); + default: + throw new UnsupportedOperationException("Unsupported predicate type: " + type); + } + } + + @Override + public SqlNode visitComparison(SqlBaseParser.ComparisonContext ctx) { + SqlParserPos position = getPos(ctx); + SqlNode left = visit(ctx.left); + SqlNode right = visit(ctx.right); + TerminalNode operator = (TerminalNode) ctx.comparisonOperator().getChild(0); + switch (operator.getSymbol().getType()) { + case SqlBaseParser.EQ: + return EQUALS.createCall(position, left, right); + case SqlBaseParser.NSEQ: + throw new UnsupportedOperationException("Unsupported operator: NSEQ (NullSafeEqual)"); + case SqlBaseParser.NEQ: + case SqlBaseParser.NEQJ: + return NOT_EQUALS.createCall(position, left, right); + case SqlBaseParser.LT: + return LESS_THAN.createCall(position, left, right); + case SqlBaseParser.LTE: + return LESS_THAN_OR_EQUAL.createCall(position, left, right); + case SqlBaseParser.GT: + return GREATER_THAN.createCall(position, left, right); + case SqlBaseParser.GTE: + return GREATER_THAN_OR_EQUAL.createCall(position, left, right); + } + throw new UnsupportedOperationException("Unsupported comparison operator: " + operator.getText()); + } + + @Override + public SqlNodeList visitSelectClause(SqlBaseParser.SelectClauseContext ctx) { + return getChildSqlNodeList(ctx.namedExpressionSeq()); + } + + @Override + public SqlNodeList visitGroupByClause(SqlBaseParser.GroupByClauseContext ctx) { + return getChildSqlNodeList(ctx.children); + } + + @Override + public SqlNode visitTableIdentifier(SqlBaseParser.TableIdentifierContext ctx) { + // the database qualifier is optional in the grammar, so guard against a missing db token + if (ctx.db == null) { + return createSqlIdentifier(getPos(ctx), ctx.table.getText()); + } + return createSqlIdentifier(getPos(ctx), ctx.db.getText(), ctx.table.getText()); + } + + @Override + public SqlNode visitTableName(SqlBaseParser.TableNameContext ctx) { + if (ctx.tableAlias().children != null) { + List<SqlNode> operands = new ArrayList<>(); + operands.add(visit(ctx.getChild(0))); + operands.addAll(visitTableAlias(ctx.tableAlias()).getList()); + return AS.createCall(getPos(ctx), operands); + } + return visitMultipartIdentifier(ctx.multipartIdentifier()); + } + + @Override + public SqlNodeList visitTableAlias(SqlBaseParser.TableAliasContext ctx) { + List<SqlNode> operands = new ArrayList<>(); + operands.add(visit(ctx.getChild(0))); + if (ctx.children.size() > 1) { + operands.addAll(((SqlNodeList) visit(ctx.getChild(1))).getList()); + } + return new SqlNodeList(operands, getPos(ctx)); + } + + @Override + public SqlNode visitQuotedIdentifierAlternative(SqlBaseParser.QuotedIdentifierAlternativeContext ctx) { + return createSqlIdentifier(getPos(ctx), ctx.getText()); + } + + @Override + public SqlNode visitUnquotedIdentifier(SqlBaseParser.UnquotedIdentifierContext ctx) { + return createSqlIdentifier(getPos(ctx), ctx.getText()); + } + + @Override + public SqlNode visitMultipartIdentifier(SqlBaseParser.MultipartIdentifierContext ctx) { + return createSqlIdentifier(getPos(ctx), + ctx.parts.stream().map(part -> part.identifier().getText()).toArray(String[]::new)); + } + + @Override + public SqlNode visitCast(SqlBaseParser.CastContext ctx) { + SqlDataTypeSpec spec = toSqlDataTypeSpec(ctx); + return CAST.createCall(getPos(ctx), visit(ctx.expression()), spec); + } + + private SqlDataTypeSpec toSqlDataTypeSpec(SqlBaseParser.CastContext ctx) { + return SqlTypeUtil.convertTypeToSpec( +
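// assumes the raw type text (e.g. "bigint") uppercases to a Calcite SqlTypeName; parameterized types such as DECIMAL(10,2) would not resolve via SqlTypeName.valueOf +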
SQL_TYPE_FACTORY.createSqlType(SqlTypeName.valueOf(ctx.dataType().getText().toUpperCase()))); + } + + @Override + public SqlNode visitFunctionCall(SqlBaseParser.FunctionCallContext ctx) { + SqlIdentifier functionName = createSqlIdentifier(getPos(ctx), ctx.functionName().getText()); + SqlUnresolvedFunction unresolvedFunction = + new SqlUnresolvedFunction(functionName, null, null, null, null, SqlFunctionCategory.USER_DEFINED_FUNCTION); + List<SqlNode> operands = ctx.argument.stream().map(this::visit).collect(Collectors.toList()); + return createCall(unresolvedFunction, operands, getPos(ctx)); + } + + @Override + public SqlNode visitStatementDefault(SqlBaseParser.StatementDefaultContext ctx) { + return ctx.query().accept(this); + } + + @Override + public SqlNode visitQueryTermDefault(SqlBaseParser.QueryTermDefaultContext ctx) { + return getChildSqlNodeList(ctx); + } + + @Override + public SqlNode visitSubscript(SqlBaseParser.SubscriptContext ctx) { + throw new UnhandledASTNodeException(ctx, UNSUPPORTED_EXCEPTION_MSG); + } + + @Override + public SqlNode visitNamedExpression(SqlBaseParser.NamedExpressionContext ctx) { + if (ctx.name != null) { + return AS.createCall(getPos(ctx), visit(ctx.getChild(0)), + createSqlIdentifier(getPos(ctx), ctx.name.identifier().getText())); + } + String text = ctx.getText(); + if (text.equals("*")) { + return createStarIdentifier(getPos(ctx)); + } + return super.visitNamedExpression(ctx); + } + + @Override + public SqlNode visitIdentifierList(SqlBaseParser.IdentifierListContext ctx) { + List<SqlNode> identifiers = ctx.identifierSeq().ident.stream() + .map(identifier -> createSqlIdentifier(getPos(identifier), identifier.getText())).collect(Collectors.toList()); + return new SqlNodeList(identifiers, getPos(ctx)); + } + + @Override + public SqlNode visitColumnReference(SqlBaseParser.ColumnReferenceContext ctx) { + return createSqlIdentifier(getPos(ctx), ctx.getText()); + } + + @Override + public SqlNode visitArithmeticBinary(SqlBaseParser.ArithmeticBinaryContext ctx) { + SqlNode left = visit(ctx.left); + SqlNode right = visit(ctx.right); + SqlParserPos position = getPos(ctx); + switch (ctx.operator.getType()) { + case SqlBaseParser.ASTERISK: + return MULTIPLY.createCall(position, left, right); + case SqlBaseParser.SLASH: + return DIVIDE.createCall(position, left, right); + case SqlBaseParser.PERCENT: + return PERCENT_REMAINDER.createCall(position, left, right); + case SqlBaseParser.DIV: + return DIVIDE_INTEGER.createCall(position, left, right); + case SqlBaseParser.PLUS: + return PLUS.createCall(position, left, right); + case SqlBaseParser.MINUS: + return MINUS.createCall(position, left, right); + case SqlBaseParser.CONCAT_PIPE: + return CONCAT.createCall(position, left, right); + case SqlBaseParser.AMPERSAND: + throw new UnsupportedOperationException("Unsupported arithmetic binary: &"); + case SqlBaseParser.HAT: + throw new UnsupportedOperationException("Unsupported arithmetic binary: ^"); + case SqlBaseParser.PIPE: + throw new UnsupportedOperationException("Unsupported arithmetic binary: |"); + } + throw new UnsupportedOperationException("Unsupported arithmetic binary: " + ctx.operator.getText()); + } + + @Override + public SqlNode visitBooleanLiteral(SqlBaseParser.BooleanLiteralContext ctx) { + boolean value = ctx.getText().equalsIgnoreCase("true"); + return CalciteUtil.createLiteralBoolean(value, getPos(ctx)); + } + + @Override + public SqlNode visitNumericLiteral(SqlBaseParser.NumericLiteralContext ctx) { + return SqlLiteral.createExactNumeric(ctx.getText(), getPos(ctx)); + } + +
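// the visitor below strips only the outer quote pair from string literals; embedded escape sequences are left as-is (a simplification of Spark's unescaping rules) +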
@Override + public SqlNode visitStringLiteral(SqlBaseParser.StringLiteralContext ctx) { + String text = ctx.getText().substring(1, ctx.getText().length() - 1); + return CalciteUtil.createStringLiteral(text, getPos(ctx)); + } + + private SqlNodeList getChildSqlNodeList(ParserRuleContext ctx) { + return new SqlNodeList(getChildren(ctx), getPos(ctx)); + } + + private SqlNodeList getChildSqlNodeList(List<ParseTree> nodes) { + return new SqlNodeList(toListOfSqlNode(nodes), ZERO); + } + + private List<SqlNode> getChildren(ParserRuleContext node) { + return toListOfSqlNode(node.children); + } + + private List<SqlNode> toListOfSqlNode(List<ParseTree> nodes) { + if (nodes == null) { + return Collections.emptyList(); + } + return nodes.stream().map(this::visit).filter(Objects::nonNull).collect(Collectors.toList()); + } + + private SqlParserPos getPos(ParserRuleContext ctx) { + if (ctx.start != null) { + return new SqlParserPos(ctx.start.getLine(), ctx.start.getStartIndex()); + } + return ZERO; + } +} diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java new file mode 100644 index 000000000..611b11fc4 --- /dev/null +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/spark2rel/parsetree/UnhandledASTNodeException.java @@ -0,0 +1,23 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.spark.spark2rel.parsetree; + +import org.antlr.v4.runtime.ParserRuleContext; + + +public class UnhandledASTNodeException extends RuntimeException { + public UnhandledASTNodeException(ParserRuleContext context, String message) { + super(String.format(message, context.getClass().getSimpleName(), getLine(context), getColumn(context))); + } + + private static int getLine(ParserRuleContext context) { + return context.start.getLine(); + } + + private static int getColumn(ParserRuleContext context) { + return context.start.getStartIndex(); + } +} diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java index 27d6884b1..7a0b5fb4f 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/ExtractUnionFunctionTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java index a727ca37c..cda8ccb28 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FallBackToLinkedInHiveUDFTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license.
* See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java index 42ec14e8b..7655f954f 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/FuzzyUnionGenericProjectTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2023 LinkedIn Corporation. All rights reserved. + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java index 272fd197d..cf522b910 100644 --- a/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java +++ b/coral-spark/src/main/java/com/linkedin/coral/spark/transformers/TransportUDFTransformer.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java b/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java index dc0ae839c..ba943fdcc 100644 --- a/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java +++ b/coral-spark/src/spark3test/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java b/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java index 9d68f3c6d..69e246662 100644 --- a/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java +++ b/coral-spark/src/sparktest/java/com/linkedin/coral/spark/TransportUDFTransformerTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java index f4bb82491..2554ace8e 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java index 1760817e8..dbcedb20a 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDF2.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java index 980c06621..e8fa7379e 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUDTF.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2022 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java index 61031127a..1e0515e87 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUdfSquare.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java index 134119268..2a031915a 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java +++ b/coral-spark/src/test/java/com/linkedin/coral/hive/hive2rel/CoralTestUnsupportedUDF.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2020 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java index 666ea87e3..19e1c346c 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/CoralSparkTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. 
*/ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java index 0c3634b05..40a98f6b0 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/FuzzyUnionViewTest.java @@ -1,5 +1,5 @@ /** - * Copyright 2019-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2019-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java b/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java index 845e4ba39..1a97883d7 100644 --- a/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/TestUtils.java @@ -1,5 +1,5 @@ /** - * Copyright 2018-2023 LinkedIn Corporation. All rights reserved. + * Copyright 2018-2024 LinkedIn Corporation. All rights reserved. * Licensed under the BSD-2 Clause license. * See LICENSE in the project root for license information. */ diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java new file mode 100644 index 000000000..2710b72e5 --- /dev/null +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/SparkToRelConverterTest.java @@ -0,0 +1,174 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.spark.spark2rel; + +import java.io.File; +import java.io.IOException; +import java.util.Iterator; +import java.util.Map; + +import com.google.common.collect.ImmutableList; + +import org.apache.calcite.plan.RelOptUtil; +import org.apache.calcite.rel.RelNode; +import org.apache.calcite.sql.type.ReturnTypes; +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterTest; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import com.linkedin.coral.hive.hive2rel.functions.StaticHiveFunctionRegistry; +import com.linkedin.coral.trino.rel2trino.RelToTrinoConverter; + +import static com.linkedin.coral.spark.spark2rel.Spark2CoralOperatorTransformerMapUtils.createOperator; +import static com.linkedin.coral.spark.spark2rel.Spark2CoralOperatorTransformerMapUtils.createTransformerMapEntry; +import static com.linkedin.coral.spark.spark2rel.ToRelTestUtils.CORAL_FROM_TRINO_TEST_DIR; +import static com.linkedin.coral.spark.spark2rel.ToRelTestUtils.sparkToRelConverter; +import static org.apache.calcite.sql.type.OperandTypes.NILADIC; +import static org.apache.calcite.sql.type.OperandTypes.NUMERIC; +import static org.apache.calcite.sql.type.OperandTypes.NUMERIC_NUMERIC; +import static org.apache.calcite.sql.type.OperandTypes.or; +import static org.testng.AssertJUnit.assertEquals; + + +public class SparkToRelConverterTest { + private static HiveConf conf; + + @BeforeClass + public void beforeClass() throws HiveException, IOException, MetaException { + // 
Simulating a Coral environment where "foo" exists + StaticHiveFunctionRegistry.createAddUserDefinedFunction("foo", ReturnTypes.INTEGER, + or(NILADIC, NUMERIC, NUMERIC_NUMERIC)); + + conf = ToRelTestUtils.loadResourceHiveConf(); + ToRelTestUtils.initializeViews(conf); + + Map<String, OperatorTransformer> transformerMap = Spark2CoralOperatorTransformerMap.TRANSFORMER_MAP; + + // foo(a) or foo() + createTransformerMapEntry(transformerMap, createOperator("foo", ReturnTypes.INTEGER, or(NILADIC, NUMERIC)), 1, + "foo", null, null); + + // foo(a, b) => foo((10 * a) + (10 * b)) + createTransformerMapEntry(transformerMap, createOperator("foo", ReturnTypes.INTEGER, NUMERIC_NUMERIC), 2, "foo", + "[{\"op\":\"+\",\"operands\":[{\"op\":\"*\",\"operands\":[{\"value\":10},{\"input\":1}]},{\"op\":\"*\",\"operands\":[{\"value\":10},{\"input\":2}]}]}]", + null); + } + + @AfterTest + public void afterClass() throws IOException { + FileUtils.deleteDirectory(new File(conf.get(CORAL_FROM_TRINO_TEST_DIR))); + } + + @DataProvider(name = "support") + public Iterator<Object[]> getSupportedSql() { + return ImmutableList.<Spark2TrinoDataProvider>builder() + .add(new Spark2TrinoDataProvider("select * from foo", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("select * from foo /* end */", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("/* start */ select * from foo", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("/* start */ select * /* middle */ from foo /* end */", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("-- start \n select * -- junk -- hi\n from foo -- done", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT *\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("select * from foo a (v, w, x, y, z)", + "LogicalProject(V=[$0], W=[$1], X=[$2], Y=[$3], Z=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"show\" AS \"V\", \"a\" AS \"W\", \"b\" AS \"X\", \"x\" AS \"Y\", \"y\" AS \"Z\"\n" + + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("select *, 123, * from foo", + "LogicalProject(show=[$0], a=[$1], b=[$2], x=[$3], y=[$4], EXPR$5=[123], show0=[$0], a0=[$1], b0=[$2], x0=[$3], y0=[$4])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"show\", \"a\", \"b\", \"x\", \"y\", 123, \"show\" AS \"show0\", \"a\" AS \"a0\", \"b\" AS \"b0\", \"x\" AS \"x0\", \"y\" AS \"y0\"\n" + + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("select show from foo", + "LogicalProject(SHOW=[$0])\n" + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"show\" AS \"SHOW\"\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("select 1 + 13 || '15' from foo", + "LogicalProject(EXPR$0=[||(CAST(+(1, 13)):VARCHAR(65535) NOT NULL, '15')])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT CAST(1 + 13 AS VARCHAR(65535)) || '15'\n" + "FROM \"default\".\"foo\"")) + .add(new
Spark2TrinoDataProvider("select x is distinct from y from foo where a is not distinct from b", + "LogicalProject(EXPR$0=[AND(OR(IS NOT NULL($3), IS NOT NULL($4)), IS NOT TRUE(=($3, $4)))])\n" + + " LogicalFilter(condition=[NOT(AND(OR(IS NOT NULL($1), IS NOT NULL($2)), IS NOT TRUE(=($1, $2))))])\n" + + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT (\"x\" IS NOT NULL OR \"y\" IS NOT NULL) AND \"x\" = \"y\" IS NOT TRUE\n" + + "FROM \"default\".\"foo\"\n" + + "WHERE NOT ((\"a\" IS NOT NULL OR \"b\" IS NOT NULL) AND \"a\" = \"b\" IS NOT TRUE)")) + .add(new Spark2TrinoDataProvider("select cast('123' as bigint)", + "LogicalProject(EXPR$0=[CAST('123'):BIGINT])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT CAST('123' AS BIGINT)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select a `my price` from `foo` `ORDERS`", + "LogicalProject(MY PRICE=[$1])\n" + " LogicalTableScan(table=[[hive, default, foo]])\n", + "SELECT \"a\" AS \"MY PRICE\"\n" + "FROM \"default\".\"foo\"")) + .add(new Spark2TrinoDataProvider("select * from a limit all", + "LogicalProject(b=[$0], id=[$1], x=[$2])\n" + " LogicalTableScan(table=[[hive, default, a]])\n", + "SELECT *\n" + "FROM \"default\".\"a\"")) + .add(new Spark2TrinoDataProvider("select * from a order by x limit all", + "LogicalSort(sort0=[$2], dir0=[ASC-nulls-first])\n" + " LogicalProject(b=[$0], id=[$1], x=[$2])\n" + + " LogicalTableScan(table=[[hive, default, a]])\n", + "SELECT *\n" + "FROM \"default\".\"a\"\n" + "ORDER BY \"x\" NULLS FIRST")) + .add(new Spark2TrinoDataProvider("select strpos('foobar', 'b') as pos", + "LogicalProject(POS=[instr('FOOBAR', 'B')])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"strpos\"('FOOBAR', 'B') AS \"POS\"\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select foo(3)", + "LogicalProject(EXPR$0=[foo(3)])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"(3)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select FOO(3)", + "LogicalProject(EXPR$0=[foo(3)])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"(3)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select foo()", + "LogicalProject(EXPR$0=[foo()])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"()\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .add(new Spark2TrinoDataProvider("select foo(10, 2)", + "LogicalProject(EXPR$0=[foo(+(*(10, 10), *(10, 2)))])\n" + " LogicalValues(tuples=[[{ 0 }]])\n", + "SELECT \"foo\"(10 * 10 + 10 * 2)\n" + "FROM (VALUES (0)) AS \"t\" (\"ZERO\")")) + .build().stream().map(x -> new Object[] { x.sparkSql, x.explain, x.trinoSql }).iterator(); + } + + private static class Spark2TrinoDataProvider { + private final String sparkSql; + private final String explain; + private final String trinoSql; + + public Spark2TrinoDataProvider(@Language("SQL") String sparkSql, String explain, @Language("SQL") String trinoSql) { + this.sparkSql = sparkSql; + this.explain = explain; + this.trinoSql = trinoSql; + } + } + + // TODO: Add unsupported SQL tests + + @Test(dataProvider = "support") + public void testSupport(String sparkSql, String expectedRelString, String expectedSql) { + RelNode relNode = sparkToRelConverter.convertSql(sparkSql); + assertEquals(expectedRelString, RelOptUtil.toString(relNode)); + + RelToTrinoConverter relToTrinoConverter = new RelToTrinoConverter(); + // Convert the rel node back to Trino SQL + String expandedSql = relToTrinoConverter.convert(relNode); +
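// round-trip check: the Trino SQL regenerated from the RelNode should match the expected translation +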
assertEquals(expectedSql, expandedSql); + } + +} diff --git a/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java new file mode 100644 index 000000000..f34e1f514 --- /dev/null +++ b/coral-spark/src/test/java/com/linkedin/coral/spark/spark2rel/ToRelTestUtils.java @@ -0,0 +1,74 @@ +/** + * Copyright 2023-2024 LinkedIn Corporation. All rights reserved. + * Licensed under the BSD-2 Clause license. + * See LICENSE in the project root for license information. + */ +package com.linkedin.coral.spark.spark2rel; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.util.UUID; + +import org.apache.commons.io.FileUtils; +import org.apache.hadoop.hive.conf.HiveConf; +import org.apache.hadoop.hive.metastore.api.MetaException; +import org.apache.hadoop.hive.ql.CommandNeedRetryException; +import org.apache.hadoop.hive.ql.Driver; +import org.apache.hadoop.hive.ql.metadata.Hive; +import org.apache.hadoop.hive.ql.metadata.HiveException; +import org.apache.hadoop.hive.ql.processors.CommandProcessorResponse; +import org.apache.hadoop.hive.ql.session.SessionState; + +import com.linkedin.coral.common.HiveMscAdapter; + + +public class ToRelTestUtils { + public static final String CORAL_FROM_TRINO_TEST_DIR = "coral.spark.test.dir"; + + private static HiveMscAdapter hiveMetastoreClient; + public static SparkToRelConverter sparkToRelConverter; + + static void run(Driver driver, String sql) { + while (true) { + try { + CommandProcessorResponse result = driver.run(sql); + if (result.getException() != null) { + throw new RuntimeException("Execution failed for: " + sql, result.getException()); + } + } catch (CommandNeedRetryException e) { + continue; + } + break; + } + } + + public static void initializeViews(HiveConf conf) throws HiveException, MetaException, IOException { + String testDir = conf.get(CORAL_FROM_TRINO_TEST_DIR) + UUID.randomUUID(); + System.out.println("Test Workspace: " + testDir); + FileUtils.deleteDirectory(new File(testDir)); + SessionState.start(conf); + Driver driver = new Driver(conf); + hiveMetastoreClient = new HiveMscAdapter(Hive.get(conf).getMSC()); + sparkToRelConverter = new SparkToRelConverter(hiveMetastoreClient); + + // Views and tables used in SparkToRelConverterTest + run(driver, "CREATE DATABASE IF NOT EXISTS default"); + run(driver, "CREATE TABLE IF NOT EXISTS default.foo(show int, a int, b int, x date, y date)"); + run(driver, "CREATE TABLE IF NOT EXISTS default.my_table(x array<int>, y array<array<int>>, z int)"); + run(driver, "CREATE TABLE IF NOT EXISTS default.a(b int, id int, x int)"); + run(driver, "CREATE TABLE IF NOT EXISTS default.b(foobar int, id int, y int)"); + } + + public static HiveConf loadResourceHiveConf() { + InputStream hiveConfStream = ToRelTestUtils.class.getClassLoader().getResourceAsStream("hive.xml"); + HiveConf hiveConf = new HiveConf(); + hiveConf.set(CORAL_FROM_TRINO_TEST_DIR, + System.getProperty("java.io.tmpdir") + "/coral/trino/" + UUID.randomUUID().toString()); + hiveConf.addResource(hiveConfStream); + hiveConf.set("mapreduce.framework.name", "local-trino"); + hiveConf.set("_hive.hdfs.session.path", "/tmp/coral/trino"); + hiveConf.set("_hive.local.session.path", "/tmp/coral/trino"); + return hiveConf; + }
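
A minimal usage sketch for the new converter, assuming a metastore client wired up as in ToRelTestUtils above (all class and method names below appear in this patch; the query string is illustrative):

    SparkToRelConverter converter = new SparkToRelConverter(hiveMetastoreClient);
    RelNode relNode = converter.convertSql("select a, b from default.foo");
    // the RelNode can then be handed to any Coral rel-to-SQL converter, e.g. Trino
    String trinoSql = new RelToTrinoConverter().convert(relNode);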