feat: Implement TSQL specific function call mapper
  - Implements the mapper and tests for transformation rules on TSQL-specific
    functions.

TODO: Manually test translations in Databricks workspace
jimidle committed Aug 2, 2024
1 parent fed469e commit c919c2c
Showing 19 changed files with 238 additions and 16 deletions.
@@ -39,11 +39,28 @@ class ExpressionGenerator(val callMapper: ir.CallMapper = new ir.CallMapper())
case ua: ir.UpdateAction => updateAction(ctx, ua)
case a: ir.Assign => assign(ctx, a)
case opts: ir.Options => options(ctx, opts)
case i: ir.KnownInterval => interval(ctx, i)

case x => throw TranspileException(s"Unsupported expression: $x")
}
}

private def interval(ctx: GeneratorContext, interval: ir.KnownInterval): String = {
val iType = interval.iType match {
case ir.YEAR_INTERVAL => "YEAR"
case ir.MONTH_INTERVAL => "MONTH"
case ir.WEEK_INTERVAL => "WEEK"
case ir.DAY_INTERVAL => "DAY"
case ir.HOUR_INTERVAL => "HOUR"
case ir.MINUTE_INTERVAL => "MINUTE"
case ir.SECOND_INTERVAL => "SECOND"
case ir.MILLISECOND_INTERVAL => "MILLISECOND"
case ir.MICROSECOND_INTERVAL => "MICROSECOND"
case ir.NANOSECOND_INTERVAL => "NANOSECOND"
}
s"INTERVAL ${generate(ctx, interval.value)} ${iType}"
}

private def options(ctx: GeneratorContext, opts: ir.Options): String = {
// First gather the options that are set by expressions
val exprOptions = opts.expressionOpts.map { case (key, expression) =>
@@ -154,6 +171,9 @@ class ExpressionGenerator(val callMapper: ir.CallMapper = new ir.CallMapper())
call match {
case r: RLike => rlike(ctx, r)
case fn: ir.Fn => s"${fn.prettyName}(${fn.children.map(expression(ctx, _)).mkString(", ")})"

// The call mapper may rewrite a function call into a plain expression, which we generate directly
case e: ir.Expression => expression(ctx, e)
case _ => throw TranspileException("not implemented")
}

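To make the new branch concrete, the IR nodes below should render to the SQL text shown in each comment (a minimal sketch using only names that appear in this diff; the generator invocation itself is elided):

import com.databricks.labs.remorph.parsers.{intermediate => ir}

// Rendered by the new interval branch as: INTERVAL 42 HOUR
val hours = ir.KnownInterval(ir.Literal(42.toShort), ir.HOUR_INTERVAL)

// Rendered as: INTERVAL 7 DAY
val days = ir.KnownInterval(ir.Literal(7.toShort), ir.DAY_INTERVAL)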
@@ -195,3 +195,26 @@ case class WithOptions(input: LogicalPlan, options: Expression) extends UnaryNode
override def child: LogicalPlan = input
override def output: Seq[Attribute] = input.output
}

// Though TSQL needs only the time-based intervals, we include all the interval types
// supported by Spark SQL for completeness and future-proofing
sealed trait KnownIntervalType
case object NANOSECOND_INTERVAL extends KnownIntervalType
case object MICROSECOND_INTERVAL extends KnownIntervalType
case object MILLISECOND_INTERVAL extends KnownIntervalType
case object SECOND_INTERVAL extends KnownIntervalType
case object MINUTE_INTERVAL extends KnownIntervalType
case object HOUR_INTERVAL extends KnownIntervalType
case object DAY_INTERVAL extends KnownIntervalType
case object WEEK_INTERVAL extends KnownIntervalType
case object MONTH_INTERVAL extends KnownIntervalType
case object YEAR_INTERVAL extends KnownIntervalType

// TSQL - For translation purposes, we cannot use the standard Catalyst CalendarInterval, as it is
// not meant for code generation and converts everything to microseconds. It is much easier to extend
// the AST to represent the interval as TSQL requires it, where we need to know whether we are dealing
// with MONTHS, HOURS, etc.
case class KnownInterval(value: Expression, iType: KnownIntervalType) extends Expression {
override def children: Seq[Expression] = Seq(value)
override def dataType: DataType = UnresolvedType
}
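As a concrete instance, the shape this enables for TSQL is an ordinary Add over a KnownInterval (a hedged sketch; the construction of the column expression is elided, since the tests later in this commit use an IRHelpers.simplyNamedColumn convenience for it):

import com.databricks.labs.remorph.parsers.intermediate._

// Models col1 + INTERVAL 7 HOUR, which is what the TSqlCallMapper below builds
// for DATEADD(hour, 7, col1); `col1` stands for any column Expression
def plusSevenHours(col1: Expression): Expression =
  Add(col1, KnownInterval(Literal(7.toShort), HOUR_INTERVAL))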
@@ -0,0 +1,83 @@
package com.databricks.labs.remorph.parsers.tsql.rules

import com.databricks.labs.remorph.parsers.intermediate._

class TSqlCallMapper extends CallMapper {

override def convert(call: Fn): Expression = {
call match {
case CallFunction("DATEADD", args) =>
processDateAdd(args)
case x: CallFunction => super.convert(x)
}
}

private def processDateAdd(args: Seq[Expression]): Expression = {

// The first argument of the TSQL DATEADD function is the interval type, which is one of way too
// many strings and aliases for "day", "month", "year", etc. We need to extract this string and
// perform the translation based on what we get
val interval = args.head match {
case Column(_, id) => id.id.toLowerCase()
case _ =>
throw new IllegalArgumentException("DATEADD interval type is not valid. Should be 'day', 'month', 'year', etc.")
}

// The value is how many units, type indicated by interval, to add to the date
val value = args(1)

// And this is the thing we are going to add the value to
val objectReference = args(2)

// The interval type names are all over the place in TSQL, some of them having names that
// belie their actual function.
interval match {

// Days are all that Spark's DATE_ADD operates on, though the arguments are transposed from TSQL.
// Despite the fact that 'dayofyear' implies the number of the day within the year, it is in fact
// the same as 'day', as is 'weekday'
case "day" | "dayofyear" | "dd" | "d" | "dy" | "y" | "weekday" | "dw" | "w" =>
DateAdd(objectReference, value)

// Months are handled by the AddMonths function, with arguments transposed from TSQL
case "month" | "mm" | "m" => AddMonths(objectReference, value)

// There is no equivalent to quarter in Spark, so we use the AddMonths function and multiply the value by 3
case "quarter" | "qq" | "q" => AddMonths(objectReference, Multiply(value, Literal(3)))

// There is no equivalent to year in Spark SQL either, so we use AddMonths and multiply by 12
case "year" | "yyyy" | "yy" => AddMonths(objectReference, Multiply(value, Literal(12)))

// There is no week unit for DATE_ADD in Spark SQL, but multiplying the value by 7 gives the same effect
case "week" | "wk" | "ww" => DateAdd(objectReference, Multiply(value, Literal(7)))

// Spark's DATE_ADD cannot add hours, so we create an INTERVAL from the number of hours
// and add it to the object reference
case "hour" | "hh" => Add(objectReference, KnownInterval(value, HOUR_INTERVAL))

// Spark's DATE_ADD cannot add minutes, so we create an INTERVAL from the number of minutes
// and add it to the object reference
case "minute" | "mi" | "n" => Add(objectReference, KnownInterval(value, MINUTE_INTERVAL))

// Spark's DATE_ADD cannot add seconds, so we create an INTERVAL from the number of seconds
// and add it to the object reference
case "second" | "ss" | "s" => Add(objectReference, KnownInterval(value, SECOND_INTERVAL))

// Spark's DATE_ADD cannot add milliseconds, so we create an INTERVAL from the number of milliseconds
// and add it to the object reference
case "millisecond" | "ms" => Add(objectReference, KnownInterval(value, MILLISECOND_INTERVAL))

// Spark's DATE_ADD cannot add microseconds, so we create an INTERVAL from the number of microseconds
// and add it to the object reference
case "microsecond" | "mcs" => Add(objectReference, KnownInterval(value, MICROSECOND_INTERVAL))

// Spark's DATE_ADD cannot add nanoseconds, so we create an INTERVAL from the number of nanoseconds
// and add it to the object reference
case "nanosecond" | "ns" => Add(objectReference, KnownInterval(value, NANOSECOND_INTERVAL))

case _ =>
throw new IllegalArgumentException(
s"DATEADD interval type '${interval}' is not valid. Should be 'day', 'month', 'year', etc.")
}
}
}
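Putting the mapper together, a hedged sketch of the conversions it performs (simplyNamedColumn comes from the IRHelpers test trait used later in this commit; borrowing it outside the tests is an assumption made here for brevity):

import com.databricks.labs.remorph.parsers.intermediate._
import com.databricks.labs.remorph.parsers.tsql.rules.TSqlCallMapper

object DateAddExamples extends IRHelpers {
  private val mapper = new TSqlCallMapper

  // DATEADD(qq, 2, col1) becomes AddMonths(col1, Multiply(2, 3)), i.e. ADD_MONTHS(col1, 2 * 3)
  val quarters: Expression = mapper.convert(
    CallFunction("DATEADD", Seq(simplyNamedColumn("qq"), Literal(2.toShort), simplyNamedColumn("col1"))))

  // DATEADD(hour, 7, col1) becomes Add(col1, KnownInterval(7, HOUR_INTERVAL)),
  // which the generator renders as col1 + INTERVAL 7 HOUR
  val hours: Expression = mapper.convert(
    CallFunction("DATEADD", Seq(simplyNamedColumn("hour"), Literal(7.toShort), simplyNamedColumn("col1"))))
}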
@@ -2,15 +2,15 @@ package com.databricks.labs.remorph.transpilers

import com.databricks.labs.remorph.generators.GeneratorContext
import com.databricks.labs.remorph.generators.sql.{ExpressionGenerator, LogicalPlanGenerator}
-import com.databricks.labs.remorph.parsers.tsql.rules.{PullLimitUpwards, TopPercentToLimitSubquery, TrapInsertDefaultsAction}
+import com.databricks.labs.remorph.parsers.tsql.rules.{PullLimitUpwards, TSqlCallMapper, TopPercentToLimitSubquery, TrapInsertDefaultsAction}
import com.databricks.labs.remorph.parsers.tsql.{TSqlAstBuilder, TSqlErrorStrategy, TSqlLexer, TSqlParser}
import com.databricks.labs.remorph.parsers.{ProductionErrorCollector, intermediate => ir}
import org.antlr.v4.runtime.{CharStreams, CommonTokenStream}

class TSqlToDatabricksTranspiler extends BaseTranspiler {
private val astBuilder = new TSqlAstBuilder()
private val optimizer = ir.Rules(PullLimitUpwards, new TopPercentToLimitSubquery, TrapInsertDefaultsAction)
-private val generator = new LogicalPlanGenerator(new ExpressionGenerator())
+private val generator = new LogicalPlanGenerator(new ExpressionGenerator(new TSqlCallMapper()))

override def parse(input: String): ir.LogicalPlan = {
val inputString = CharStreams.fromString(input)
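Callers do not change; the TSQL mapper simply rides along inside the generator. A minimal parse-side sketch (only parse is visible in this diff; the optimize and generate steps that BaseTranspiler drives are elided):

val transpiler = new TSqlToDatabricksTranspiler
// Yields an IR plan containing CallFunction("DATEADD", ...), which the
// TSqlCallMapper rewrites during SQL generation
val plan = transpiler.parse("SELECT DATEADD(hour, 7, col1) AS h FROM tabl;")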
@@ -0,0 +1,96 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.generators.sql.{ExpressionGenerator, GeneratorTestCommon}
import com.databricks.labs.remorph.parsers.intermediate.IRHelpers
import com.databricks.labs.remorph.parsers.tsql.rules.TSqlCallMapper
import com.databricks.labs.remorph.parsers.{intermediate => ir}
import org.scalatest.wordspec.AnyWordSpec
import org.scalatestplus.mockito.MockitoSugar

// Only add tests here that require the TSqlCallMapper, or in the future any other transformer/rule
// that is specific to T-SQL. Otherwise they belong in ExpressionGeneratorTest.

class TSqlExpressionGeneratorTest
extends AnyWordSpec
with GeneratorTestCommon[ir.Expression]
with MockitoSugar
with IRHelpers {

override protected val generator = new ExpressionGenerator(new TSqlCallMapper)

"DATEADD" should {
"transpile to DATE_ADD" in {
ir.CallFunction(
"DATEADD",
Seq(simplyNamedColumn("day"), ir.Literal(42.toShort), simplyNamedColumn("col1"))) generates "DATE_ADD(col1, 42)"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("week"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "DATE_ADD(col1, 42 * 7)"
}

"transpile to ADD_MONTHS" in {
ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("Month"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "ADD_MONTHS(col1, 42)"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("qq"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "ADD_MONTHS(col1, 42 * 3)"
}

"transpile to INTERVAL" in {
ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("hour"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "col1 + INTERVAL 42 HOUR"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("minute"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "col1 + INTERVAL 42 MINUTE"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("second"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "col1 + INTERVAL 42 SECOND"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("millisecond"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "col1 + INTERVAL 42 MILLISECOND"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("mcs"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "col1 + INTERVAL 42 MICROSECOND"

ir.CallFunction(
"DATEADD",
Seq(
simplyNamedColumn("ns"),
ir.Literal(42.toShort),
simplyNamedColumn("col1"))) generates "col1 + INTERVAL 42 NANOSECOND"
}
}

}
@@ -7,4 +7,4 @@
SELECT DATEADD(hour, 7, col1) AS add_hours_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 HOUR) AS add_hours_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 HOUR AS add_hours_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(hh, 7, col1) AS add_hours_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 HOUR) AS add_hours_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 HOUR AS add_hours_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(MICROSECOND, 7, col1) AS add_microsecond_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MICROSECOND) AS add_microsecond_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MICROSECOND AS add_microsecond_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(mcs, 7, col1) AS add_microsecond_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MICROSECOND) AS add_microsecond_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MICROSECOND AS add_microsecond_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(millisecond, 7, col1) AS add_minutes_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MILLISECOND) AS add_minutes_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MILLISECOND AS add_minutes_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(ms, 7, col1) AS add_milliseconds_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MILLISECOND) AS add_milliseconds_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MILLISECOND AS add_milliseconds_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(minute, 7, col1) AS add_minutes_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MINUTE) AS add_minutes_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MINUTE AS add_minutes_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(mi, 7, col1) AS add_minutes_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MINUTE) AS add_minutes_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MINUTE AS add_minutes_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(n, 7, col1) AS add_minutes_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 MINUTE) AS add_minutes_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 MINUTE AS add_minutes_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(NANOSECOND, 7, col1) AS add_nanoseconds_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 NANOSECOND) AS add_nanoseconds_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 NANOSECOND AS add_nanoseconds_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(NS, 7, col1) AS add_minutes_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 NANOSECOND) AS add_minutes_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 NANOSECOND AS add_minutes_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(second, 7, col1) AS add_seconds_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 SECOND) AS add_seconds_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 SECOND AS add_seconds_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(ss, 7, col1) AS add_seconds_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 SECOND) AS add_seconds_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 SECOND AS add_seconds_col1 FROM tabl;
@@ -7,4 +7,4 @@
SELECT DATEADD(s, 7, col1) AS add_minutes_col1 FROM tabl;

-- databricks sql:
-SELECT (col1 + INTERVAL 7 SECOND) AS add_minutes_col1 FROM tabl;
+SELECT col1 + INTERVAL 7 SECOND AS add_minutes_col1 FROM tabl;
