Skip to content

Commit

Permalink
Support function translation to Databricks SQL in TSql and Snowflake (#…
Browse files Browse the repository at this point in the history
…414)

Some functions must be translated from the TSQL or Snowflake dialect into
the equivalent IR for Databricks SQL. In some cases a function must be
translated in one dialect, say TSql, but is directly equivalent in another,
say Snowflake.

Here we upgrade the FunctionBuilder system to be dialect aware, and
provide a ConversionStrategy system that allows for any type of
conversion, from a simple name translation to a more complicated specific
IR representation when there is no direct equivalent.

For example the TSQL code
```tsql
SELECT ISNULL(x, 0)
```

Should translate to:

```sql
SELECT IFNULL(x, 0)
```

In Databricks SQL, but in Snowflake SQL:

```snowflake
SELECT ISNULL(col)
```

Is directly equivalent to Databricks SQL and needs no conversion.

---------

Co-authored-by: Valentin Kasas <[email protected]>
  • Loading branch information
jimidle and vil1 authored Jun 4, 2024
1 parent d628677 commit eac0207
Show file tree
Hide file tree
Showing 14 changed files with 768 additions and 588 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.databricks.labs.remorph.parsers
import com.databricks.labs.remorph.parsers.{intermediate => ir}

import java.util.Locale

/**
 * A strategy for converting a source-dialect function call into its Databricks SQL
 * equivalent. Implementations receive the IR-level function name and its argument
 * expressions and return the ir.Expression to emit — which may be a simple renamed
 * call or an entirely different IR construct.
 */
trait ConversionStrategy {
  def convert(irName: String, args: Seq[ir.Expression]): ir.Expression
}

trait StringConverter {

  /**
   * Renames a function while preserving the casing style of the original name.
   *
   * If the original name contains no upper-case characters, the replacement is
   * returned in lower case; otherwise the replacement is returned as written
   * (conventionally upper case). All bets are off if the original name was mixed
   * case, but that is rarely seen in SQL and we are just making reasonable
   * efforts here.
   *
   * @param irName  the function name as it appeared in the source SQL
   * @param newName the replacement function name (conventionally upper case)
   * @return the replacement name, cased to match the original
   */
  def convertString(irName: String, newName: String): String = {
    // Check for "no upper-case character" rather than forall(_.isLower): isLower
    // is false for digits and underscores, so names such as "array_agg" would
    // otherwise fail the all-lower-case test and lose their casing.
    // Locale.ROOT keeps the result independent of the JVM default locale.
    if (!irName.exists(_.isUpper)) newName.toLowerCase(Locale.ROOT) else newName
  }
}

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.databricks.labs.remorph.parsers.snowflake

import com.databricks.labs.remorph.parsers.{IncompleteParser, ParserCommon, intermediate => ir}
import SnowflakeParser.{StringContext => StrContext, _}
import com.databricks.labs.remorph.parsers.intermediate.AddColumn
import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser.{StringContext => StrContext, _}
import com.databricks.labs.remorph.parsers.{IncompleteParser, ParserCommon, intermediate => ir}

import scala.collection.JavaConverters._
class SnowflakeDDLBuilder
Expand Down Expand Up @@ -36,7 +36,8 @@ class SnowflakeDDLBuilder
ir.FunctionParameter(
name = ctx.arg_name().getText,
dataType = DataTypeBuilder.buildDataType(ctx.arg_data_type().id_().data_type()),
defaultValue = Option(ctx.arg_default_value_clause()).map(_.expr().accept(new SnowflakeExpressionBuilder)))
defaultValue = Option(ctx.arg_default_value_clause())
.map(_.expr().accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder))))
}

private def buildFunctionBody(ctx: Function_definitionContext): String = (ctx match {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package com.databricks.labs.remorph.parsers.snowflake

import com.databricks.labs.remorph.parsers.{FunctionBuilder, IncompleteParser, ParserCommon, intermediate => ir}
import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser._
import com.databricks.labs.remorph.parsers.{FunctionBuilder, IncompleteParser, ParserCommon, intermediate => ir}
import org.antlr.v4.runtime.Token

import scala.collection.JavaConverters._
class SnowflakeExpressionBuilder
class SnowflakeExpressionBuilder(functionBuilder: FunctionBuilder)
extends SnowflakeParserBaseVisitor[ir.Expression]
with ParserCommon[ir.Expression]
with IncompleteParser[ir.Expression] {
Expand Down Expand Up @@ -302,10 +302,9 @@ class SnowflakeExpressionBuilder
val param = ctx.expr().accept(this)
val separator = Option(ctx.string()).map(s => ir.Literal(string = Some(removeQuotes(s.getText))))
ctx.op.getType match {
case LISTAGG => FunctionBuilder.buildFunction("LISTAGG", param +: separator.toSeq)
case ARRAY_AGG => FunctionBuilder.buildFunction("ARRAYAGG", Seq(param))
case LISTAGG => functionBuilder.buildFunction("LISTAGG", param +: separator.toSeq)
case ARRAY_AGG => functionBuilder.buildFunction("ARRAYAGG", Seq(param))
}

}
private def buildBuiltinFunction(ctx: Builtin_functionContext, param: ir.Expression): ir.Expression =
Option(ctx)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package com.databricks.labs.remorph.parsers.snowflake

import com.databricks.labs.remorph.parsers.{ConversionStrategy, FunctionBuilder, FunctionDefinition, intermediate => ir}

class SnowflakeFunctionBuilder extends FunctionBuilder {

  // Snowflake-specific function definitions, consulted before the common,
  // dialect-agnostic definitions supplied by FunctionBuilder.
  private val SnowflakeFunctionDefinitionPf: PartialFunction[String, FunctionDefinition] = {
    case "IFNULL" => FunctionDefinition.standard(2)
    case "ISNULL" => FunctionDefinition.standard(1)
  }

  /**
   * Looks up the definition for the given function name, checking the
   * Snowflake-specific table first and falling back to the common definitions.
   *
   * Locale.ROOT makes the upper-casing independent of the JVM default locale:
   * under e.g. a Turkish locale, "isnull".toUpperCase() yields "İSNULL" and the
   * lookup would silently miss.
   */
  override def functionDefinition(name: String): Option[FunctionDefinition] =
    SnowflakeFunctionDefinitionPf
      .lift(name.toUpperCase(java.util.Locale.ROOT))
      .orElse(super.functionDefinition(name))

  /**
   * Applies the definition's ConversionStrategy when one is present; otherwise
   * emits a plain CallFunction with the name unchanged.
   *
   * @param functionArity the resolved definition for the function
   * @param args          the translated argument expressions
   * @param irName        the function name as written in the source SQL
   */
  def applyConversionStrategy(
      functionArity: FunctionDefinition,
      args: Seq[ir.Expression],
      irName: String): ir.Expression = {
    functionArity.conversionStrategy match {
      case Some(strategy) => strategy.convert(irName, args)
      case _ => ir.CallFunction(irName, args)
    }
  }
}

object SnowflakeFunctionConverters {

  // Snowflake -> Databricks SQL function conversions.
  //
  // At present no Snowflake function requires renaming, so this strategy simply
  // reproduces the call unchanged. It exists so that dialect-specific renames
  // can be slotted in here, mirroring the TSQL converters.
  object FunctionRename extends ConversionStrategy {
    override def convert(irName: String, args: Seq[ir.Expression]): ir.Expression =
      ir.CallFunction(irName, args)
  }
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.databricks.labs.remorph.parsers.snowflake

import com.databricks.labs.remorph.parsers.{IncompleteParser, intermediate => ir}
import com.databricks.labs.remorph.parsers.snowflake.SnowflakeParser._
import com.databricks.labs.remorph.parsers.{IncompleteParser, intermediate => ir}
import org.antlr.v4.runtime.ParserRuleContext

import scala.collection.JavaConverters._
Expand All @@ -24,7 +24,7 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w
c.select_top_clause().select_list_top().all_distinct(),
c.select_top_clause().select_list_top().select_list().select_list_elem().asScala)
}
val expressions = selectListElements.map(_.accept(new SnowflakeExpressionBuilder))
val expressions = selectListElements.map(_.accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder)))
ir.Project(buildTop(top, buildDistinct(allOrDistinct, relation, expressions)), expressions)

}
Expand Down Expand Up @@ -74,7 +74,7 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w

private def buildFilter[A](ctx: A, conditionRule: A => ParserRuleContext, input: ir.Relation): ir.Relation =
Option(ctx).fold(input) { c =>
ir.Filter(input, conditionRule(c).accept(new SnowflakeExpressionBuilder))
ir.Filter(input, conditionRule(c).accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder)))
}
private def buildHaving(ctx: Having_clauseContext, input: ir.Relation): ir.Relation =
buildFilter[Having_clauseContext](ctx, _.search_condition(), input)
Expand All @@ -87,7 +87,10 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w
private def buildGroupBy(ctx: Group_by_clauseContext, input: ir.Relation): ir.Relation = {
Option(ctx).fold(input) { c =>
val groupingExpressions =
c.group_by_list().group_by_elem().asScala.map(_.accept(new SnowflakeExpressionBuilder))
c.group_by_list()
.group_by_elem()
.asScala
.map(_.accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder)))
val aggregate =
ir.Aggregate(input = input, group_type = ir.GroupBy, grouping_expressions = groupingExpressions, pivot = None)
buildHaving(c.having_clause(), aggregate)
Expand All @@ -97,7 +100,7 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w
private def buildOrderBy(ctx: Order_by_clauseContext, input: ir.Relation): ir.Relation = {
Option(ctx).fold(input) { c =>
val sortOrders = c.order_item().asScala.map { orderItem =>
val expression = orderItem.accept(new SnowflakeExpressionBuilder)
val expression = orderItem.accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder))
if (orderItem.DESC() == null) {
if (orderItem.NULLS() != null && orderItem.FIRST() != null) {
ir.SortOrder(expression, ir.AscendingSortDirection, ir.SortNullsFirst)
Expand All @@ -121,7 +124,7 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w
}

override def visitValues_table_body(ctx: Values_table_bodyContext): ir.Relation = {
val expressionBuilder = new SnowflakeExpressionBuilder
val expressionBuilder = new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder)
val expressions =
ctx.expr_list_in_parentheses().asScala.map(l => expressionBuilder.visitSeq(l.expr_list().expr().asScala))
ir.Values(expressions)
Expand All @@ -144,9 +147,10 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w
}

private def buildPivot(ctx: Pivot_unpivotContext, relation: ir.Relation): ir.Relation = {
val pivotValues: Seq[ir.Literal] = ctx.literal().asScala.map(_.accept(new SnowflakeExpressionBuilder)).collect {
case lit: ir.Literal => lit
}
val pivotValues: Seq[ir.Literal] =
ctx.literal().asScala.map(_.accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder))).collect {
case lit: ir.Literal => lit
}
val pivotColumn = ir.Column(ctx.id_(2).getText)
val aggregateFunction = translateAggregateFunction(ctx.id_(0), ctx.id_(1))
ir.Aggregate(
Expand All @@ -157,7 +161,11 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w
}

private def buildUnpivot(ctx: Pivot_unpivotContext, relation: ir.Relation): ir.Relation = {
val unpivotColumns = ctx.column_list().column_name().asScala.map(_.accept(new SnowflakeExpressionBuilder))
val unpivotColumns = ctx
.column_list()
.column_name()
.asScala
.map(_.accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder)))
val variableColumnName = ctx.id_(0).getText
val valueColumnName = ctx.column_name().id_(0).getText
ir.Unpivot(
Expand Down Expand Up @@ -210,7 +218,11 @@ class SnowflakeRelationBuilder extends SnowflakeParserBaseVisitor[ir.Relation] w

override def visitCommon_table_expression(ctx: Common_table_expressionContext): ir.Relation = {
val tableName = ctx.id_().getText
val columns = ctx.column_list().column_name().asScala.map(_.accept(new SnowflakeExpressionBuilder))
val columns = ctx
.column_list()
.column_name()
.asScala
.map(_.accept(new SnowflakeExpressionBuilder(new SnowflakeFunctionBuilder)))
val query = ctx.select_statement().accept(this)
ir.CTEDefinition(tableName, columns, query)
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.tsql.TSqlParser._
import com.databricks.labs.remorph.parsers.{FunctionBuilder, ParserCommon, StandardFunction, UnknownFunction, XmlFunction, intermediate => ir}
import com.databricks.labs.remorph.parsers.{FunctionBuilder, ParserCommon, XmlFunction, intermediate => ir}
import org.antlr.v4.runtime.Token
import org.antlr.v4.runtime.tree.{TerminalNode, Trees}

import scala.collection.JavaConverters._

class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with ParserCommon[ir.Expression] {
class TSqlExpressionBuilder(functionBuilder: FunctionBuilder)
extends TSqlParserBaseVisitor[ir.Expression]
with ParserCommon[ir.Expression] {

override def visitSelectListElem(ctx: TSqlParser.SelectListElemContext): ir.Expression = {
ctx match {
Expand Down Expand Up @@ -144,10 +146,9 @@ class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with Pa
case (c1: ir.Column, c2: ir.Column) =>
ir.Column(c1.name + "." + c2.name)
case (_: ir.Column, c2: ir.CallFunction) =>
FunctionBuilder.functionType(c2.function_name) match {
case StandardFunction => ir.Dot(left, right)
functionBuilder.functionType(c2.function_name) match {
case XmlFunction => ir.XmlFunction(c2, left)
case UnknownFunction => ir.Dot(left, right)
case _ => ir.Dot(left, right)
}
// Other cases
case _ => ir.Dot(left, right)
Expand Down Expand Up @@ -177,7 +178,7 @@ class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with Pa
override def visitExprDollar(ctx: ExprDollarContext): ir.Expression = ir.DollarAction()

override def visitExprFuncVal(ctx: ExprFuncValContext): ir.Expression = {
FunctionBuilder.buildFunction(ctx.getText, Seq.empty)
functionBuilder.buildFunction(ctx.getText, Seq.empty)
}

override def visitExprCollate(ctx: ExprCollateContext): ir.Expression =
Expand Down Expand Up @@ -285,7 +286,7 @@ class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with Pa
override def visitStandardFunction(ctx: StandardFunctionContext): ir.Expression = {
val name = ctx.funcId.getText
val args = Option(ctx.expression()).map(_.asScala.map(_.accept(this))).getOrElse(Seq.empty)
FunctionBuilder.buildFunction(name, args)
functionBuilder.buildFunction(name, args)
}

// Note that this visitor is made complicated and difficult because the built in ir does not use options
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.{ConversionStrategy, FunctionBuilder, FunctionDefinition, StringConverter, intermediate => ir}

class TSqlFunctionBuilder extends FunctionBuilder {

  // TSQL-specific function definitions, consulted before the common,
  // dialect-agnostic definitions supplied by FunctionBuilder.
  private val tSqlFunctionDefinitionPf: PartialFunction[String, FunctionDefinition] = {
    case "@@CURSOR_STATUS" => FunctionDefinition.notConvertible(0)
    case "@@FETCH_STATUS" => FunctionDefinition.notConvertible(0)
    // The ConversionStrategy is used to rename ISNULL to IFNULL
    case "ISNULL" => FunctionDefinition.standard(2).withConversionStrategy(TSqlFunctionConverters.FunctionRename)
    case "MODIFY" => FunctionDefinition.xml(1)
  }

  /**
   * Looks up the definition for the given function name, checking the
   * TSQL-specific table first and falling back to the common definitions.
   *
   * Locale.ROOT makes the upper-casing independent of the JVM default locale:
   * under e.g. a Turkish locale, "isnull".toUpperCase() yields "İSNULL" and the
   * lookup would silently miss.
   */
  override def functionDefinition(name: String): Option[FunctionDefinition] =
    tSqlFunctionDefinitionPf
      .lift(name.toUpperCase(java.util.Locale.ROOT))
      .orElse(super.functionDefinition(name))

  /**
   * Applies the definition's ConversionStrategy when one is present; otherwise
   * emits a plain CallFunction with the name unchanged.
   *
   * @param functionArity the resolved definition for the function
   * @param args          the translated argument expressions
   * @param irName        the function name as written in the source SQL
   */
  def applyConversionStrategy(
      functionArity: FunctionDefinition,
      args: Seq[ir.Expression],
      irName: String): ir.Expression = {
    functionArity.conversionStrategy match {
      case Some(strategy) => strategy.convert(irName, args)
      case _ => ir.CallFunction(irName, args)
    }
  }
}

// TSQL specific function converters
//
// Note that these are left as objects, though we will possibly have a class per function in the future
// Each function can specify its own ConversionStrategy, and some will need to be very specific,
// hence perhaps moving to a class per function may be a better idea.
object TSqlFunctionConverters {

  /**
   * Renames TSQL functions to their Databricks SQL equivalents, preserving the
   * casing style of the source via StringConverter. Currently handles
   * ISNULL -> IFNULL; anything else passes through unchanged.
   */
  object FunctionRename extends ConversionStrategy with StringConverter {
    override def convert(irName: String, args: Seq[ir.Expression]): ir.Expression = {
      // Locale.ROOT: under a Turkish default locale "isnull".toUpperCase()
      // yields "İSNULL", the case would not match, and the rename would be
      // silently skipped.
      irName.toUpperCase(java.util.Locale.ROOT) match {
        case "ISNULL" => ir.CallFunction(convertString(irName, "IFNULL"), args)
        case _ => ir.CallFunction(irName, args)
      }
    }
  }

}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class TSqlRelationBuilder extends TSqlParserBaseVisitor[ir.Relation] {
// TODO: Process all the other elements of a query specification

val columns =
ctx.selectListElem().asScala.map(_.accept(new TSqlExpressionBuilder()))
ctx.selectListElem().asScala.map(_.accept(new TSqlExpressionBuilder(new TSqlFunctionBuilder)))
val from = Option(ctx.tableSources()).map(_.accept(new TSqlRelationBuilder)).getOrElse(ir.NoTable)
// Note that ALL is the default so we don't need to check for it
ctx match {
Expand Down Expand Up @@ -100,7 +100,7 @@ class TSqlRelationBuilder extends TSqlParserBaseVisitor[ir.Relation] {
private def buildJoin(left: ir.Relation, right: JoinPartContext): ir.Join = {
val joinExpression = right.joinOn()
val rightRelation = joinExpression.tableSource().accept(this)
val joinCondition = joinExpression.searchCondition().accept(new TSqlExpressionBuilder)
val joinCondition = joinExpression.searchCondition().accept(new TSqlExpressionBuilder(new TSqlFunctionBuilder))

ir.Join(
left,
Expand Down
Loading

0 comments on commit eac0207

Please sign in to comment.