Skip to content

Commit

Permalink
Merge branch 'main' into feature/recon_documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
sundarshankar89 authored Jun 8, 2024
2 parents ab1342b + 0783f10 commit 3e25e7a
Show file tree
Hide file tree
Showing 46 changed files with 1,937 additions and 395 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3423,7 +3423,7 @@ id_
| non_reserved_words
| object_type_plural
| data_type
| builtin_function
| builtin_function_name
| unary_or_binary_builtin_function
| binary_builtin_function
| binary_or_ternary_builtin_function
Expand Down Expand Up @@ -3560,7 +3560,7 @@ non_reserved_words
| MODE
;

builtin_function
builtin_function_name
// If there is a lexer entry for a function we also need to add the token here
// as it otherwise will not be picked up by the id_ rule (See also derived rule below)
: SUM
Expand Down Expand Up @@ -3701,7 +3701,6 @@ expr
| cast_expr #exprCast
| expr COLON_COLON data_type #exprAscribe
| json_literal #exprJsonLit
| trim_expression #exprTrim
| function_call #exprFuncCall
// Probably wrong
| subquery #exprSubquery
Expand Down Expand Up @@ -3794,14 +3793,12 @@ data_type
;

// Primitive (leaf) expressions. Every alternative carries a `#` label so the
// visitor can dispatch on a dedicated context class (e.g. PrimExprColumnContext).
// NOTE: ANTLR requires either all or none of a rule's alternatives to be
// labelled, so the old unlabelled alternatives were dropped; dotted identifier
// access is now reached through full_column_name.
primitive_expression
    : DEFAULT          # primExprDefault //?
    | full_column_name # primExprColumn
    | literal          # primExprLiteral
    | BOTH_Q           # primExprBoth
    | ARRAY_Q          # primExprArray
    | OBJECT_Q         # primExprObject
    //| json_literal
    //| arr_literal
    ;
Expand Down Expand Up @@ -3829,19 +3826,23 @@ over_clause
;

// Entry point for all function invocations. Builtins that need special
// argument shapes go through builtin_function; a plain identifier with an
// argument list falls through to standard_function. The old
// unary/binary/ternary builtin alternatives moved to builtin_function
// (kept there as comments).
function_call
    : builtin_function
    | standard_function
    | ranking_windowed_function
    | aggregate_function
    // | aggregate_windowed_function
    | object_name L_PAREN expr_list? R_PAREN
    | object_name L_PAREN param_assoc_list R_PAREN
    | list_function L_PAREN expr_list R_PAREN
    | to_date = ( TO_DATE | DATE) L_PAREN expr R_PAREN
    | length = ( LENGTH | LEN) L_PAREN expr R_PAREN
    | TO_BOOLEAN L_PAREN expr R_PAREN
    ;

// Builtin functions that need a dedicated, labelled alternative (currently
// only the TRIM family, whose optional second argument is a string literal).
// The `trim` token label lets the visitor recover which variant matched.
// The commented alternatives below were folded out of function_call and are
// presumably kept for reference until migrated — TODO confirm before deleting.
builtin_function
    : trim = (TRIM | LTRIM | RTRIM) L_PAREN expr (COMMA string)? R_PAREN #builtinTrim
    // : unary_or_binary_builtin_function L_PAREN expr (COMMA expr)* R_PAREN
    // | binary_builtin_function L_PAREN expr COMMA expr R_PAREN
    // | binary_or_ternary_builtin_function L_PAREN expr COMMA expr (COMMA expr)* R_PAREN
    // | ternary_builtin_function L_PAREN expr COMMA expr COMMA expr R_PAREN
    ;

// Generic function invocation: any identifier followed by an optional
// comma-separated argument list, e.g. FOO(a, b). Arity checking is not done
// here — it is delegated to the Scala FunctionBuilder at visit time.
standard_function
    : id_ L_PAREN expr_list? R_PAREN
    ;

param_assoc_list
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3687,11 +3687,11 @@ subquery
;

// WITH clause: one or more comma-separated common table expressions.
// The old `ctes +=` element labels were removed; the visitor now iterates the
// generated commonTableExpression() list directly.
withExpression
    : WITH commonTableExpression (COMMA commonTableExpression)*
    ;

// A single CTE: name, optional column list, and the AS (...) query.
// Element labels (expressionName=, columns=, cteQuery=) were removed in favour
// of the default generated accessors.
commonTableExpression
    : id (LPAREN columnNameList RPAREN)? AS LPAREN selectStatement RPAREN
    ;

updateElem
Expand Down Expand Up @@ -3863,7 +3863,6 @@ selectListElem
: asterisk
| LOCAL_ID op=(PE | ME | SE | DE | MEA | AND_ASSIGN | XOR_ASSIGN | OR_ASSIGN | EQ) expression
| expressionElem
| udtElem // TODO: May not be needed as expressionElem could handle this?
;

tableSources
Expand Down Expand Up @@ -4294,7 +4293,7 @@ insertColumnId
;

// Comma-separated list of column identifiers. The `col +=` element labels
// were removed; callers use the generated id() list accessor instead.
columnNameList
    : id (COMMA id)*
    ;

cursorName
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ object FunctionDefinition {

abstract class FunctionBuilder {

private val functionDefinitionPf: PartialFunction[String, FunctionDefinition] = {
protected val commonFunctionsPf: PartialFunction[String, FunctionDefinition] = {
case "ABS" => FunctionDefinition.standard(1)
case "ACOS" => FunctionDefinition.standard(1)
case "APP_NAME" => FunctionDefinition.standard(0)
Expand Down Expand Up @@ -177,6 +177,7 @@ abstract class FunctionBuilder {
case "LEFT" => FunctionDefinition.standard(2)
case "LEN" => FunctionDefinition.standard(1)
case "LISTAGG" => FunctionDefinition.standard(1, 2)
case "LN" => FunctionDefinition.standard(1)
case "LOG" => FunctionDefinition.standard(1, 2)
case "LOG10" => FunctionDefinition.standard(1)
case "LOGINPROPERTY" => FunctionDefinition.standard(2)
Expand Down Expand Up @@ -279,7 +280,7 @@ abstract class FunctionBuilder {
}

/** Looks up the arity/type definition for a function by name.
  *
  * Lookup is case-insensitive (names are upper-cased before matching) and
  * returns None for functions this builder does not know about.
  * Note: the stale reference to the old `functionDefinitionPf` name was
  * removed — the partial function was renamed to `commonFunctionsPf`.
  */
def functionDefinition(name: String): Option[FunctionDefinition] =
  commonFunctionsPf.lift(name.toUpperCase())

def functionType(name: String): FunctionType = {
functionDefinition(name).map(_.functionType).getOrElse(UnknownFunction)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ class SnowflakeDDLBuilder
extends SnowflakeParserBaseVisitor[ir.Catalog]
with ParserCommon[ir.Catalog]
with IncompleteParser[ir.Catalog] {

private val expressionBuilder = new SnowflakeExpressionBuilder

override protected def wrapUnresolvedInput(unparsedInput: String): ir.Catalog = ir.UnresolvedCatalog(unparsedInput)

private def extractString(ctx: StrContext): String = {
Expand Down Expand Up @@ -37,7 +40,7 @@ class SnowflakeDDLBuilder
name = ctx.arg_name().getText,
dataType = DataTypeBuilder.buildDataType(ctx.arg_data_type().id_().data_type()),
defaultValue = Option(ctx.arg_default_value_clause())
.map(_.expr().accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder))))
.map(_.expr().accept(expressionBuilder)))
}

private def buildFunctionBody(ctx: Function_definitionContext): String = (ctx match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
package com.databricks.labs.remorph.parsers.snowflake

import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser._
import com.databricks.labs.remorph.parsers.{FunctionBuilder, IncompleteParser, ParserCommon, intermediate => ir}
import com.databricks.labs.remorph.parsers.{IncompleteParser, ParserCommon, intermediate => ir}
import org.antlr.v4.runtime.Token

import scala.collection.JavaConverters._
class SnowflakeExpressionBuilder(functionBuilder: FunctionBuilder)
class SnowflakeExpressionBuilder()
extends SnowflakeParserBaseVisitor[ir.Expression]
with ParserCommon[ir.Expression]
with IncompleteParser[ir.Expression] {

private val functionBuilder = new SnowflakeFunctionBuilder

protected override def wrapUnresolvedInput(unparsedInput: String): ir.UnresolvedExpression =
ir.UnresolvedExpression(unparsedInput)
override def visitSelect_list_elem(ctx: Select_list_elemContext): ir.Expression = {
Expand Down Expand Up @@ -38,13 +40,9 @@ class SnowflakeExpressionBuilder(functionBuilder: FunctionBuilder)
ir.Column(ctx.id_(0).getText)
}

/** Visits the labelled `# primExprColumn` alternative of primitive_expression.
  *
  * Builds a dotted column reference by joining the identifiers of the
  * full_column_name with '.', e.g. `a.b.c` becomes ir.Column("a.b.c").
  * This replaces the superseded visitPrimitive_expression, which manually
  * tested ctx.id_() emptiness before the grammar alternatives were labelled.
  */
override def visitPrimExprColumn(ctx: PrimExprColumnContext): ir.Expression = {
  val columnName = ctx.full_column_name().id_().asScala.map(_.getText).mkString(".")
  ir.Column(columnName)
}

override def visitOrder_item(ctx: Order_itemContext): ir.Expression = {
Expand Down Expand Up @@ -289,13 +287,21 @@ class SnowflakeExpressionBuilder(functionBuilder: FunctionBuilder)
}
}

/** Translates a standard_function parse node into an IR function call.
  *
  * The function name is taken verbatim from the identifier; each argument
  * expression is visited recursively. An absent argument list (a nullary
  * call such as FOO()) yields an empty argument sequence.
  */
override def visitStandard_function(ctx: Standard_functionContext): ir.Expression = {
  val name = ctx.id_().getText
  val args = ctx.expr_list() match {
    case null => Seq.empty[ir.Expression]
    case list => list.expr().asScala.map(_.accept(this))
  }
  functionBuilder.buildFunction(name, args)
}

// aggregate_function

/** Aggregate call whose argument is an expression list.
  *
  * Only the first expression of the list is translated; it is then wrapped by
  * the matching builtin aggregate node (AVG, SUM, ...) when the function name
  * is a recognised builtin, otherwise returned unchanged.
  */
override def visitAggFuncExprList(ctx: AggFuncExprListContext): ir.Expression = {
  val firstArg = ctx.expr_list().expr(0).accept(this)
  buildBuiltinFunction(ctx.id_().builtin_function_name(), firstArg)
}

/** Aggregate call over `*` (e.g. COUNT(*)): the argument is a Star node with
  * no qualifier, wrapped by the matching builtin aggregate when recognised.
  */
override def visitAggFuncStar(ctx: AggFuncStarContext): ir.Expression =
  buildBuiltinFunction(ctx.id_().builtin_function_name(), ir.Star(None))

override def visitAggFuncList(ctx: AggFuncListContext): ir.Expression = {
Expand All @@ -306,7 +312,9 @@ class SnowflakeExpressionBuilder(functionBuilder: FunctionBuilder)
case ARRAY_AGG => functionBuilder.buildFunction("ARRAYAGG", Seq(param))
}
}
private def buildBuiltinFunction(ctx: Builtin_functionContext, param: ir.Expression): ir.Expression =
// end aggregate_function

private def buildBuiltinFunction(ctx: Builtin_function_nameContext, param: ir.Expression): ir.Expression =
Option(ctx)
.collect {
case c if c.AVG() != null => ir.Avg(param)
Expand All @@ -316,6 +324,11 @@ class SnowflakeExpressionBuilder(functionBuilder: FunctionBuilder)
}
.getOrElse(param)

/** Translates the #builtinTrim alternative (TRIM / LTRIM / RTRIM).
  *
  * The IR function is named after whichever token matched the `trim` label in
  * the grammar. The optional second argument (the characters to trim) is
  * appended only when present in the parse tree.
  */
override def visitBuiltinTrim(ctx: BuiltinTrimContext): ir.Expression = {
  val subject = ctx.expr().accept(this)
  val trimChars = Option(ctx.string()).map(_.accept(this))
  functionBuilder.buildFunction(ctx.trim.getText, subject :: trimChars.toList)
}
override def visitCase_expression(ctx: Case_expressionContext): ir.Expression = {
val exprs = ctx.expr().asScala
val otherwise = Option(ctx.ELSE()).flatMap(els => exprs.find(occursBefore(els, _)).map(_.accept(this)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,150 @@ import com.databricks.labs.remorph.parsers.{ConversionStrategy, FunctionBuilder,
class SnowflakeFunctionBuilder extends FunctionBuilder {

// Snowflake-specific function arities, consulted before the dialect-agnostic
// commonFunctionsPf (see functionDefinition below). Names are matched
// upper-cased. Commented-out cases mark functions that need a non-standard
// definition and are not yet supported.
// Fix: removed a stale duplicate `case "IFNULL" => standard(2)` that preceded
// the intended `standard(1, 2)` entry — a PartialFunction takes the FIRST
// matching case, so the duplicate shadowed it and rejected 1-argument IFNULL.
private val SnowflakeFunctionDefinitionPf: PartialFunction[String, FunctionDefinition] = {
  case "ADD_MONTHS" => FunctionDefinition.standard(2)
  case "ANY_VALUE" => FunctionDefinition.standard(1)
  case "APPROX_COUNT_DISTINCT" => FunctionDefinition.standard(1)
  case "APPROX_PERCENTILE" => FunctionDefinition.standard(2)
  case "APPROX_PERCENTILE_CONT" => FunctionDefinition.standard(1)
  case "APPROX_PERCENTILE_DISC" => FunctionDefinition.standard(1)
  case "APPROX_TOP_K" => FunctionDefinition.standard(1, 3)
  // case "ARRAY" => ???
  case "ARRAYS_OVERLAP" => FunctionDefinition.standard(2)
  case "ARRAY_AGG" => FunctionDefinition.standard(1)
  case "ARRAY_APPEND" => FunctionDefinition.standard(2)
  case "ARRAY_CAT" => FunctionDefinition.standard(2)
  case "ARRAY_COMPACT" => FunctionDefinition.standard(1)
  case "ARRAY_CONSTRUCT" => FunctionDefinition.standard(0, Int.MaxValue)
  case "ARRAY_CONSTRUCT_COMPACT" => FunctionDefinition.standard(0, Int.MaxValue)
  case "ARRAY_CONTAINS" => FunctionDefinition.standard(2)
  case "ARRAY_DISTINCT" => FunctionDefinition.standard(1)
  case "ARRAY_EXCEPT" => FunctionDefinition.standard(2)
  case "ARRAY_INSERT" => FunctionDefinition.standard(3)
  case "ARRAY_INTERSECTION" => FunctionDefinition.standard(2)
  case "ARRAY_POSITION" => FunctionDefinition.standard(2)
  case "ARRAY_PREPEND" => FunctionDefinition.standard(2)
  case "ARRAY_REMOVE" => FunctionDefinition.standard(2)
  case "ARRAY_SIZE" => FunctionDefinition.standard(1)
  case "ARRAY_SLICE" => FunctionDefinition.standard(3)
  case "ARRAY_TO_STRING" => FunctionDefinition.standard(2)
  case "ATAN2" => FunctionDefinition.standard(2)
  case "BASE64_DECODE_STRING" => FunctionDefinition.standard(1, 2)
  case "BASE64_ENCODE" => FunctionDefinition.standard(1, 3)
  case "BITOR_AGG" => FunctionDefinition.standard(1)
  case "BOOLAND_AGG" => FunctionDefinition.standard(1)
  case "CEIL" => FunctionDefinition.standard(1, 2)
  case "COLLATE" => FunctionDefinition.standard(2)
  case "COLLATION" => FunctionDefinition.standard(1)
  case "CONTAINS" => FunctionDefinition.standard(2)
  case "CONVERT_TIMEZONE" => FunctionDefinition.standard(2, 3)
  case "CORR" => FunctionDefinition.standard(2)
  case "COUNT_IF" => FunctionDefinition.standard(1)
  case "CURRENT_DATABASE" => FunctionDefinition.standard(0)
  case "CURRENT_TIMESTAMP" => FunctionDefinition.standard(0, 1)
  case "DATEDIFF" => FunctionDefinition.standard(3)
  case "DATE_FROM_PARTS" => FunctionDefinition.standard(3)
  case "DATE_PART" => FunctionDefinition.standard(2)
  case "DATE_TRUNC" => FunctionDefinition.standard(2)
  case "DAYNAME" => FunctionDefinition.standard(1)
  case "DECODE" => FunctionDefinition.standard(3, Int.MaxValue)
  case "DIV0" => FunctionDefinition.standard(2)
  case "DIV0NULL" => FunctionDefinition.standard(2)
  case "ENDSWITH" => FunctionDefinition.standard(2)
  case "EQUAL_NULL" => FunctionDefinition.standard(2)
  // case "FLATTEN" => ???
  case "GET" => FunctionDefinition.standard(2)
  case "IFNULL" => FunctionDefinition.standard(1, 2)
  case "INITCAP" => FunctionDefinition.standard(1, 2)
  case "ISNULL" => FunctionDefinition.standard(1)
  case "IS_INTEGER" => FunctionDefinition.standard(1)
  case "JSON_EXTRACT_PATH_TEXT" => FunctionDefinition.standard(2)
  case "LAST_DAY" => FunctionDefinition.standard(1, 2)
  case "LPAD" => FunctionDefinition.standard(2, 3)
  case "LTRIM" => FunctionDefinition.standard(1, 2)
  case "MEDIAN" => FunctionDefinition.standard(1)
  case "MOD" => FunctionDefinition.standard(2)
  case "MODE" => FunctionDefinition.standard(1)
  case "MONTHNAME" => FunctionDefinition.standard(1)
  // case "MONTH_NAME" => ???
  case "NEXT_DAY" => FunctionDefinition.standard(2)
  // case "NTH_VALUE" =>
  case "NULLIFZERO" => FunctionDefinition.standard(1)
  case "NVL" => FunctionDefinition.standard(2)
  case "NVL2" => FunctionDefinition.standard(3)
  case "OBJECT_CONSTRUCT" => FunctionDefinition.standard(1, Int.MaxValue)
  case "OBJECT_KEYS" => FunctionDefinition.standard(1)
  case "PARSE_JSON" => FunctionDefinition.standard(1)
  case "PARSE_URL" => FunctionDefinition.standard(1, 2)
  case "POSITION" => FunctionDefinition.standard(2, 3)
  case "RANDOM" => FunctionDefinition.standard(0, 1)
  case "RANK" => FunctionDefinition.standard(0)
  case "REGEXP_COUNT" => FunctionDefinition.standard(2, 4)
  case "REGEXP_INSTR" => FunctionDefinition.standard(2, 7)
  case "REGEXP_LIKE" => FunctionDefinition.standard(2, 3)
  case "REGEXP_REPLACE" => FunctionDefinition.standard(2, 6)
  case "REGEXP_SUBSTR" => FunctionDefinition.standard(2, 6)
  case "REGR_INTERCEPT" => FunctionDefinition.standard(2)
  case "REGR_R2" => FunctionDefinition.standard(2)
  case "REGR_SLOPE" => FunctionDefinition.standard(2)
  case "REPEAT" => FunctionDefinition.standard(2)
  case "ROUND" => FunctionDefinition.standard(1, 3)
  case "RPAD" => FunctionDefinition.standard(2, 3)
  case "RTRIM" => FunctionDefinition.standard(1, 2)
  case "SPLIT_PART" => FunctionDefinition.standard(3)
  case "SQUARE" => FunctionDefinition.standard(1)
  case "STARTSWITH" => FunctionDefinition.standard(2)
  case "STDDEV" => FunctionDefinition.standard(1)
  case "STDDEV_POP" => FunctionDefinition.standard(1)
  case "STDDEV_SAMP" => FunctionDefinition.standard(1)
  case "STRIP_NULL_VALUE" => FunctionDefinition.standard(1)
  case "STRTOK" => FunctionDefinition.standard(1, 3)
  case "STRTOK_TO_ARRAY" => FunctionDefinition.standard(1, 2)
  case "SYSDATE" => FunctionDefinition.standard(0)
  case "TIMEADD" => FunctionDefinition.standard(3)
  case "TIMESTAMPADD" => FunctionDefinition.standard(3)
  case "TIMESTAMPDIFF" => FunctionDefinition.standard(3)
  case "TIMESTAMP_FROM_PARTS" => FunctionDefinition.standard(2, 8)
  case "TO_ARRAY" => FunctionDefinition.standard(1)
  case "TO_BOOLEAN" => FunctionDefinition.standard(1)
  case "TO_CHAR" => FunctionDefinition.standard(1, 2)
  case "TO_DATE" => FunctionDefinition.standard(1, 2)
  case "TO_DECIMAL" => FunctionDefinition.standard(1, 4)
  case "TO_DOUBLE" => FunctionDefinition.standard(1, 2)
  case "TO_JSON" => FunctionDefinition.standard(1)
  case "TO_NUMBER" => FunctionDefinition.standard(1, 4)
  case "TO_NUMERIC" => FunctionDefinition.standard(1, 4)
  case "TO_OBJECT" => FunctionDefinition.standard(1)
  case "TO_TIME" => FunctionDefinition.standard(1, 2)
  case "TO_TIMESTAMP" => FunctionDefinition.standard(1, 2)
  case "TO_TIMESTAMP_LTZ" => FunctionDefinition.standard(1, 2)
  case "TO_TIMESTAMP_NTZ" => FunctionDefinition.standard(1, 2)
  case "TO_TIMESTAMP_TZ" => FunctionDefinition.standard(1, 2)
  case "TO_VARCHAR" => FunctionDefinition.standard(1, 2)
  case "TRIM" => FunctionDefinition.standard(1, 2)
  case "TRUNC" => FunctionDefinition.standard(2)
  case "TRY_BASE64_DECODE_STRING" => FunctionDefinition.standard(1, 2)
  case "TRY_PARSE_JSON" => FunctionDefinition.standard(1)
  case "TRY_TO_BINARY" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_BOOLEAN" => FunctionDefinition.standard(1)
  case "TRY_TO_DATE" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_DECIMAL" => FunctionDefinition.standard(1, 4)
  case "TRY_TO_DOUBLE" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_NUMBER" => FunctionDefinition.standard(1, 4)
  case "TRY_TO_NUMERIC" => FunctionDefinition.standard(1, 4)
  case "TRY_TO_TIME" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_TIMESTAMP" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_TIMESTAMP_LTZ" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_TIMESTAMP_NTZ" => FunctionDefinition.standard(1, 2)
  case "TRY_TO_TIMESTAMP_TZ" => FunctionDefinition.standard(1, 2)
  case "TYPEOF" => FunctionDefinition.standard(1)
  case "UUID_STRING" => FunctionDefinition.standard(0, 2)
  case "ZEROIFNULL" => FunctionDefinition.standard(1)
}

/** Resolves a function definition, preferring the Snowflake-specific table.
  *
  * Lookup is case-insensitive. Snowflake-specific entries take precedence;
  * if none matches, the dialect-agnostic commonFunctionsPf is consulted.
  * Fix: merged diff residue had kept the superseded
  * `...lift(...).orElse(super.functionDefinition(name))` body alongside the
  * new one — only the composed-PartialFunction form remains.
  */
override def functionDefinition(name: String): Option[FunctionDefinition] =
  // If not found, check common functions
  SnowflakeFunctionDefinitionPf.orElse(commonFunctionsPf).lift(name.toUpperCase())

def applyConversionStrategy(
functionArity: FunctionDefinition,
Expand Down
Loading

0 comments on commit 3e25e7a

Please sign in to comment.