Skip to content

Commit

Permalink
feat: Support function translation in the TSQL and Snowflake dialects
Browse files Browse the repository at this point in the history
  - Some functions must be translated from TSQL or Snowflake versions into the
    equivalent IR for Databricks SQL.
  - Provides a ConversionStrategy system that allows for simple name translations
    or more complicated IR representations when there is no equivalent.
  - Modifies FunctionBuilders etc to accept different dialects.
  • Loading branch information
jimidle committed May 30, 2024
1 parent 5ffc7eb commit f56ed2f
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,14 @@ case object UnknownFunction extends FunctionType

/**
 * Describes how many arguments a known function accepts, together with the information
 * needed to build its IR. Instances are what the function-name lookup in FunctionBuilder
 * hands back to buildFunction.
 */
sealed trait FunctionArity {
// Reports the `convertible` flag that the implementing case class was constructed with.
// NOTE(review): the consumers of this flag are not visible here - confirm its exact meaning.
def isConvertible: Boolean
// Optional strategy used to build the IR for a call to this function. When it is None,
// applyConversionStrategy emits a plain ir.CallFunction with the original name and arguments.
def conversionStrategy: Option[ConversionStrategy]
}

case class FixedArity(arity: Int, functionType: FunctionType = StandardFunction, convertible: Boolean = true)
/**
 * Arity definition for a function that takes exactly `arity` arguments.
 *
 * @param arity exact number of arguments; buildFunction only builds the call when
 *              `args.length` equals this value
 * @param functionType category of the function; defaults to StandardFunction
 * @param convertible value reported by `isConvertible`; defaults to true
 * @param conversionStrategy optional strategy applied when building the IR for the call;
 *                           None (the default) results in a plain ir.CallFunction
 */
case class FixedArity(
arity: Int,
functionType: FunctionType = StandardFunction,
convertible: Boolean = true,
override val conversionStrategy: Option[ConversionStrategy] = None)
extends FunctionArity {
// Simply exposes the constructor flag through the FunctionArity interface.
override def isConvertible: Boolean = convertible
}
Expand All @@ -22,7 +27,8 @@ case class VariableArity(
argMin: Int,
argMax: Int,
functionType: FunctionType = StandardFunction,
convertible: Boolean = true)
convertible: Boolean = true,
override val conversionStrategy: Option[ConversionStrategy] = None)
extends FunctionArity {
override def isConvertible: Boolean = convertible
}
Expand Down Expand Up @@ -142,7 +148,9 @@ object FunctionBuilder {
case "ISDATE" => Some(FixedArity(1))
case "ISDESCENDANTOF" => Some(FixedArity(1))
case "ISJSON" => Some(VariableArity(1, 2))
case "ISNULL" => Some(FixedArity(2))
// ISNULL takes one argument in Snowflake (and other dialects) but two arguments in TSql,
// so it is declared with variable arity (1 to 2 arguments) rather than a fixed arity.
// The ConversionStrategy renames ISNULL to IFNULL when the dialect is TSql.
case "ISNULL" => Some(VariableArity(1, 2, conversionStrategy = Some(FunctionConverters.FunctionRename)))
case "ISNUMERIC" => Some(FixedArity(1))
case "JSON_MODIFY" => Some(FixedArity(3))
case "JSON_PATH_EXISTS" => Some(FixedArity(2))
Expand Down Expand Up @@ -258,7 +266,7 @@ object FunctionBuilder {
}
}

def buildFunction(name: String, args: Seq[ir.Expression]): ir.Expression = {
def buildFunction(name: String, args: Seq[ir.Expression], dialect: SqlDialect): ir.Expression = {
val irName = removeQuotesAndBrackets(name)
val uName = irName.toUpperCase(Locale.getDefault())
val defnOption = functionArity(uName)
Expand All @@ -268,11 +276,11 @@ object FunctionBuilder {
ir.UnresolvedFunction(name, args, is_distinct = false, is_user_defined_function = false)

case Some(fixedArity: FixedArity) if args.length == fixedArity.arity =>
ir.CallFunction(irName, args)
applyConversionStrategy(fixedArity, args, irName, dialect)

case Some(variableArity: VariableArity)
if args.length >= variableArity.argMin && args.length <= variableArity.argMax =>
ir.CallFunction(irName, args)
applyConversionStrategy(variableArity, args, irName, dialect)

// Found the function but the arg count is incorrect
case Some(_) =>
Expand All @@ -289,6 +297,17 @@ object FunctionBuilder {
}
}

/**
 * Builds the IR expression for a function call, honouring any conversion strategy
 * attached to the function's arity definition.
 *
 * buildFunction calls this once the supplied argument count has matched the arity.
 *
 * @param functionArity arity definition that may carry a conversion strategy
 * @param args          already-built IR expressions for the call arguments
 * @param irName        function name with surrounding quotes/brackets removed
 * @param dialect       source SQL dialect, forwarded to the strategy
 * @return the strategy's result when one is defined, otherwise a plain
 *         `ir.CallFunction(irName, args)`
 */
def applyConversionStrategy(
    functionArity: FunctionArity,
    args: Seq[ir.Expression],
    irName: String,
    dialect: SqlDialect): ir.Expression =
  functionArity.conversionStrategy
    .map(_.convert(irName, args, dialect))
    .getOrElse(ir.CallFunction(irName, args))

/**
* Functions can be called even if they are quoted or bracketed. This function removes the quotes and brackets.
* @param str
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package com.databricks.labs.remorph.parsers
import com.databricks.labs.remorph.parsers.{intermediate => ir}

import java.util.Locale

// A set of functions that convert between different representations of SQL functions, sometimes
// depending on the SQL dialect. They could be one big match statement if all the conversions are trivial renames,
// however, if the conversions are more complex, they can be implemented as separate conversion strategies.

/**
 * A strategy that produces the IR expression for a function call, optionally taking the
 * source dialect into account (for example, renaming the function).
 */
sealed trait ConversionStrategy {
// Builds the IR for a call to `irName` with `args`, as seen in the given source `dialect`.
// FunctionBuilder passes the function name after quotes/brackets have been removed.
def convert(irName: String, args: Seq[ir.Expression], dialect: SqlDialect): ir.Expression
}

/**
 * Holder for the available [[ConversionStrategy]] implementations and their shared helpers.
 */
object FunctionConverters {

  /**
   * Adapts the case of a replacement function name to the case of the original name.
   *
   * If every letter of `irName` is lower case, `newName` is returned in lower case;
   * otherwise `newName` is returned exactly as supplied. Non-letter characters (digits,
   * underscores, ...) are ignored when deciding, so a name such as `json_value` still
   * counts as lower case. Mixed-case names are not specially handled: this is a
   * best-effort nicety only.
   *
   * @param irName  the function name as written in the source SQL
   * @param newName the replacement name, in its canonical spelling
   * @return `newName`, lower-cased when the original name was lower case
   */
  private def convertString(irName: String, newName: String): String = {
    val originalIsLowerCase = irName.forall(c => !c.isLetter || c.isLower)
    if (originalIsLowerCase) newName.toLowerCase(Locale.ROOT) else newName
  }

  /**
   * Strategy for functions whose only difference between dialects is their name.
   * Names/dialects without a known rename are passed through unchanged.
   */
  object FunctionRename extends ConversionStrategy {
    override def convert(irName: String, args: Seq[ir.Expression], dialect: SqlDialect): ir.Expression = {
      // Locale.ROOT keeps the comparison independent of the JVM default locale; with a
      // Turkish default locale "isnull".toUpperCase() yields "İSNULL" and would never match.
      (irName.toUpperCase(Locale.ROOT), dialect) match {
        case ("ISNULL", TSql) => ir.CallFunction(convertString(irName, "IFNULL"), args)
        case _ => ir.CallFunction(irName, args)
      }
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,7 @@ trait ParserCommon {
a != null && b != null && a.getSourceInterval.startsBeforeDisjoint(b.getSourceInterval)
}
}

// Identifies the source SQL dialect being parsed. It is passed to FunctionBuilder.buildFunction
// and forwarded to ConversionStrategy.convert so that conversions can be dialect specific.
sealed trait SqlDialect
// The T-SQL dialect; supplied by the TSql expression builder when it builds functions.
case object TSql extends SqlDialect
// The Snowflake dialect.
case object Snowflake extends SqlDialect
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.tsql.TSqlParser._
import com.databricks.labs.remorph.parsers.{FunctionBuilder, ParserCommon, StandardFunction, UnknownFunction, XmlFunction, intermediate => ir}
import com.databricks.labs.remorph.parsers.{FunctionBuilder, ParserCommon, StandardFunction, TSql, UnknownFunction, XmlFunction, intermediate => ir}
import org.antlr.v4.runtime.Token
import org.antlr.v4.runtime.tree.{TerminalNode, Trees}

Expand Down Expand Up @@ -281,7 +281,7 @@ class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with Pa
override def visitStandardFunction(ctx: StandardFunctionContext): ir.Expression = {
val name = ctx.funcId.getText
val args = Option(ctx.expression()).map(_.asScala.map(_.accept(this))).getOrElse(Seq.empty)
FunctionBuilder.buildFunction(name, args)
FunctionBuilder.buildFunction(name, args, TSql)
}

// Note that this visitor is made complicated and difficult because the built in ir does not use options
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.databricks.labs.remorph.parsers.common

import com.databricks.labs.remorph.parsers.intermediate.UnresolvedFunction
import com.databricks.labs.remorph.parsers.{FixedArity, FunctionArity, FunctionBuilder, FunctionType, StandardFunction, UnknownFunction, VariableArity, XmlFunction}
import com.databricks.labs.remorph.parsers.{intermediate => ir, _}
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.prop.TableDrivenPropertyChecks
Expand Down Expand Up @@ -126,7 +125,7 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
("ISDATE", Some(FixedArity(1))),
("ISDESCENDANTOF", Some(FixedArity(1))),
("ISJSON", Some(VariableArity(1, 2))),
("ISNULL", Some(FixedArity(2))),
("ISNULL", Some(VariableArity(1, 2, conversionStrategy = Some(FunctionConverters.FunctionRename)))),
("ISNUMERIC", Some(FixedArity(1))),
("JSON_MODIFY", Some(FixedArity(3))),
("JSON_PATH_EXISTS", Some(FixedArity(2))),
Expand Down Expand Up @@ -283,39 +282,71 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp

"buildFunction" should "remove quotes and brackets from function names" in {
// Test function name with less than 2 characters
val result1 = FunctionBuilder.buildFunction("a", Seq())
val result1 = FunctionBuilder.buildFunction("a", List.empty, TSql)
result1 match {
case f: UnresolvedFunction => f.function_name shouldBe "a"
case f: ir.UnresolvedFunction => f.function_name shouldBe "a"
case _ => fail("Unexpected function type")
}

// Test function name with matching quotes
val result2 = FunctionBuilder.buildFunction("'quoted'", Seq())
val result2 = FunctionBuilder.buildFunction("'quoted'", List.empty, TSql)
result2 match {
case f: UnresolvedFunction => f.function_name shouldBe "quoted"
case f: ir.UnresolvedFunction => f.function_name shouldBe "quoted"
case _ => fail("Unexpected function type")
}

// Test function name with matching brackets
val result3 = FunctionBuilder.buildFunction("[bracketed]", Seq())
val result3 = FunctionBuilder.buildFunction("[bracketed]", List.empty, TSql)
result3 match {
case f: UnresolvedFunction => f.function_name shouldBe "bracketed"
case f: ir.UnresolvedFunction => f.function_name shouldBe "bracketed"
case _ => fail("Unexpected function type")
}

// Test function name with matching backslashes
val result4 = FunctionBuilder.buildFunction("\\backslashed\\", Seq())
val result4 = FunctionBuilder.buildFunction("\\backslashed\\", List.empty, TSql)
result4 match {
case f: UnresolvedFunction => f.function_name shouldBe "backslashed"
case f: ir.UnresolvedFunction => f.function_name shouldBe "backslashed"
case _ => fail("Unexpected function type")
}

// Test function name with non-matching quotes
val result5 = FunctionBuilder.buildFunction("'nonmatching", Seq())
val result5 = FunctionBuilder.buildFunction("'nonmatching", List.empty, TSql)
result5 match {
case f: UnresolvedFunction => f.function_name shouldBe "'nonmatching"
case f: ir.UnresolvedFunction => f.function_name shouldBe "'nonmatching"
case _ => fail("Unexpected function type")
}
}

// A two-argument ISNULL parsed from TSql must come out as a call to IFNULL.
"buildFunction" should "Apply known TSQL conversion strategies" in {
  val isNullArgs = Seq(ir.Column("x"), ir.Literal(integer = Some(0)))
  FunctionBuilder.buildFunction("ISNULL", isNullArgs, TSql) match {
    case call: ir.CallFunction => call.function_name shouldBe "IFNULL"
    case _ => fail("ISNULL TSql conversion failed")
  }
}

// The TSql-only rename must not fire for Snowflake: ISNULL keeps its name.
"buildFunction" should "Ignore TSQL conversion strategies in Snowflake dialect" in {
  val isNullArgs = Seq(ir.Column("x"), ir.Literal(integer = Some(0)))
  FunctionBuilder.buildFunction("ISNULL", isNullArgs, Snowflake) match {
    case call: ir.CallFunction => call.function_name shouldBe "ISNULL"
    case _ => fail("ISNULL Snowflake conversion failed")
  }
}

// A lower-case source name must produce a lower-case replacement name.
// The description no longer repeats "Should", so the test reads
// "buildFunction should preserve case if it can".
"buildFunction" should "preserve case if it can" in {
  val isNullArgs = Seq(ir.Column("x"), ir.Literal(integer = Some(0)))
  FunctionBuilder.buildFunction("isnull", isNullArgs, TSql) match {
    case call: ir.CallFunction => call.function_name shouldBe "ifnull"
    case _ => fail("ifnull conversion failed")
  }
}

// A name with no registered rename must be passed through untouched, including its case.
"FunctionRename strategy" should "preserve original function if no match is found" in {
  val absArgs = Seq(ir.Literal(integer = Some(66)))
  FunctionConverters.FunctionRename.convert("Abs", absArgs, TSql) match {
    case call: ir.CallFunction => call.function_name shouldBe "Abs"
    // The message names the function actually under test, not "UNKNOWN_FUNCTION".
    case _ => fail("FunctionRename did not return a CallFunction for the unmatched name Abs")
  }
}
}

0 comments on commit f56ed2f

Please sign in to comment.