Skip to content

Commit

Permalink
[SQL Generator] Generate SQL code for literals (#706)
Browse files Browse the repository at this point in the history
Co-authored-by: Valentin Kasas <[email protected]>
  • Loading branch information
nfx and vil1 authored Jul 24, 2024
1 parent e90eecf commit ec13a9e
Show file tree
Hide file tree
Showing 7 changed files with 194 additions and 47 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.databricks.labs.remorph.generators

case class GeneratorContext(
maxLineWidth: Int = 120,
private val indent: Int = 0,
private val layer: Int = 0,
private val joins: Int = 0,
wrapLiteral: Boolean = true) {
def nest: GeneratorContext =
GeneratorContext(maxLineWidth = maxLineWidth, joins = joins, layer = layer, indent = indent + 1)

def ws: String = " " * indent

def subQuery: GeneratorContext =
GeneratorContext(maxLineWidth = maxLineWidth, joins = joins, layer = layer + 1, indent = indent + 1)

def layerName: String = s"layer_$layer"

def withRawLiteral: GeneratorContext =
GeneratorContext(maxLineWidth = maxLineWidth, joins = joins, indent = indent, layer = layer, wrapLiteral = false)

def hasJoins: Boolean = joins > 0
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package com.databricks.labs.remorph.generators.sql

import com.databricks.labs.remorph.generators.GeneratorContext
import com.databricks.labs.remorph.parsers.intermediate.Literal
import com.databricks.labs.remorph.parsers.{intermediate => ir}

import java.text.SimpleDateFormat
import java.util.Locale

class ExpressionGenerator {
private val dateFormat = new SimpleDateFormat("yyyy-MM-dd")
private val timeFormat = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss.SSS")

def expression(ctx: GeneratorContext, expr: ir.Expression): String = {
expr match {
case l: ir.Literal => literal(ctx, l)
case _ => throw new IllegalArgumentException(s"Unsupported expression: $expr")
}
}

private def literal(ctx: GeneratorContext, l: Literal): String = {
l.dataType match {
case ir.NullType => "NULL"
case ir.BinaryType => orNull(l.binary.map(_.map("%02X" format _).mkString))
case ir.BooleanType => orNull(l.boolean.map(_.toString.toUpperCase(Locale.getDefault)))
case ir.ShortType => orNull(l.short.map(_.toString))
case ir.IntegerType => orNull(l.integer.map(_.toString))
case ir.LongType => orNull(l.long.map(_.toString))
case ir.FloatType => orNull(l.float.map(_.toString))
case ir.DoubleType => orNull(l.double.map(_.toString))
case ir.StringType => orNull(l.string.map(doubleQuote))
case ir.DateType =>
l.date match {
case Some(date) => doubleQuote(dateFormat.format(date))
case None => "NULL"
}
case ir.TimestampType =>
l.timestamp match {
case Some(timestamp) => doubleQuote(timeFormat.format(timestamp))
case None => "NULL"
}
case ir.ArrayType(_) => orNull(l.array.map(arrayExpr(ctx)))
case ir.MapType(_, _) => orNull(l.map.map(mapExpr(ctx)))
case _ => throw new IllegalArgumentException(s"Unsupported expression: ${l.dataType}")
}
}

private def mapExpr(ctx: GeneratorContext)(map: ir.MapExpr): String = {
val entries = map.keys.zip(map.values).map { case (key, value) =>
s"${literal(ctx, key)}, ${expression(ctx, value)}"
}
// TODO: line-width formatting
s"MAP(${entries.mkString(", ")})"
}

private def arrayExpr(ctx: GeneratorContext)(array: ir.ArrayExpr): String = {
val elements = array.elements.map { element =>
expression(ctx, element)
}
// TODO: line-width formatting
s"ARRAY(${elements.mkString(", ")})"
}

private def orNull(option: Option[String]): String = option.getOrElse("NULL")

private def doubleQuote(s: String): String = s""""$s""""
}
Original file line number Diff line number Diff line change
Expand Up @@ -147,24 +147,47 @@ case class CalendarInterval(months: Int, days: Int, microseconds: Long) extends
override def dataType: DataType = CalendarIntervalType
}

case class ArrayExpr(dataType: DataType, elements: Seq[Literal]) extends Expression {
case class ArrayExpr(dataType: DataType, elements: Seq[Expression]) extends Expression {
override def children: Seq[Expression] = elements
}

case class JsonExpr(dataType: DataType, fields: Seq[(String, Literal)]) extends Expression {
override def children: Seq[Expression] = fields.map(_._2)
}

case class MapExpr(key_type: DataType, value_type: DataType, keys: Seq[Literal], values: Seq[Literal])
case class MapExpr(key_type: DataType, value_type: DataType, keys: Seq[Literal], values: Seq[Expression])
extends Expression {
override def children: Seq[Expression] = keys ++ values
override def dataType: DataType = MapType() // TODO: Fix this
override def dataType: DataType = MapType(key_type, value_type)
}

case class Struct(dataType: DataType, elements: Seq[Literal]) extends Expression {
override def children: Seq[Expression] = elements
}

object Literal {
def apply(value: Array[Byte]): Literal = Literal(binary = Some(value))
def apply(value: Boolean): Literal = Literal(boolean = Some(value))
def apply(value: Byte): Literal = Literal(byte = Some(value.intValue()))
def apply(value: Short): Literal = Literal(short = Some(value.intValue()))
def apply(value: Int): Literal = Literal(integer = Some(value))
def apply(value: Long): Literal = Literal(long = Some(value))
def apply(value: Float): Literal = Literal(float = Some(value))
def apply(value: Decimal): Literal = Literal(decimal = Some(value))
def apply(value: Double): Literal = Literal(double = Some(value))
def apply(value: String): Literal = Literal(string = Some(value))
def apply(value: java.sql.Date): Literal = Literal(date = Some(value.getTime))
def apply(value: java.sql.Timestamp): Literal = Literal(timestamp = Some(value.getTime))
def apply(value: Map[String, Expression]): Literal = Literal(map = Some(
MapExpr(
StringType,
value.values.head.dataType,
value.keys.map(key => Literal(string = Some(key))).toSeq,
value.values.toSeq)))
def apply(values: Seq[Expression]): Literal =
Literal(array = Some(ArrayExpr(values.headOption.map(_.dataType).getOrElse(UnresolvedType), values)))
}

case class Literal(
nullType: Option[DataType] = None,
binary: Option[Array[Byte]] = None,
Expand All @@ -187,47 +210,30 @@ case class Literal(
map: Option[MapExpr] = None,
json: Option[JsonExpr] = None)
extends LeafExpression {

override def dataType: DataType = {
if (binary.isDefined) {
BinaryType
} else if (boolean.isDefined) {
BooleanType
} else if (byte.isDefined) {
ByteType(byte)
} else if (short.isDefined) {
ShortType
} else if (integer.isDefined) {
IntegerType
} else if (long.isDefined) {
LongType
} else if (float.isDefined) {
FloatType
} else if (decimal.isDefined) {
DecimalType(decimal.get.precision, decimal.get.scale)
} else if (double.isDefined) {
DoubleType
} else if (string.isDefined) {
StringType
} else if (date.isDefined) {
DateType
} else if (timestamp.isDefined) {
TimestampType
} else if (timestamp_ntz.isDefined) {
TimestampNTZType
} else if (calendar_interval.isDefined) {
CalendarIntervalType
} else if (year_month_interval.isDefined) {
YearMonthIntervalType
} else if (day_time_interval.isDefined) {
DayTimeIntervalType
} else if (array.isDefined) {
ArrayType()
} else if (map.isDefined) {
MapType()
} else if (json.isDefined) {
UDTType()
} else {
NullType
this match {
case _ if binary.isDefined => BinaryType
case _ if boolean.isDefined => BooleanType
case _ if byte.isDefined => ByteType(byte)
case _ if short.isDefined => ShortType
case _ if integer.isDefined => IntegerType
case _ if long.isDefined => LongType
case _ if float.isDefined => FloatType
case _ if decimal.isDefined => DecimalType(decimal.get.precision, decimal.get.scale)
case _ if double.isDefined => DoubleType
case _ if string.isDefined => StringType
case _ if date.isDefined => DateType
case _ if timestamp.isDefined => TimestampType
case _ if timestamp_ntz.isDefined => TimestampNTZType
case _ if calendar_interval.isDefined => CalendarIntervalType
case _ if year_month_interval.isDefined => YearMonthIntervalType
case _ if day_time_interval.isDefined => DayTimeIntervalType
case _ if array.isDefined => ArrayType(array.map(_.dataType).getOrElse(UnresolvedType))
case _ if map.isDefined =>
MapType(map.map(_.key_type).getOrElse(UnresolvedType), map.map(_.value_type).getOrElse(UnresolvedType))
case _ if json.isDefined => UDTType()
case _ => NullType
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ case object YearMonthIntervalType extends DataType
case object DayTimeIntervalType extends DataType

// Complex types
case class ArrayType() extends DataType
case class ArrayType(elementType: DataType) extends DataType
case class StructType() extends DataType
case class MapType() extends DataType
case class MapType(keyType: DataType, valueType: DataType) extends DataType

// UserDefinedType
case class UDTType() extends DataType
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ object DataTypeBuilder {
case c if c.charAlias != null => ir.CharType(sizeOpt)
case c if c.varcharAlias != null => ir.VarCharType(sizeOpt)
case c if c.binaryAlias != null => ir.BinaryType
case c if c.ARRAY() != null => ir.ArrayType()
case c if c.ARRAY() != null => ir.ArrayType(ir.UnresolvedType)
case _ => ir.UnparsedType()
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
package com.databricks.labs.remorph.generators.sql

import com.databricks.labs.remorph.generators.GeneratorContext
import org.scalatest.matchers.should.Matchers
import org.scalatest.wordspec.AnyWordSpec
import org.scalatestplus.mockito.MockitoSugar
import com.databricks.labs.remorph.parsers.{intermediate => ir}

class ExpressionGeneratorTest extends AnyWordSpec with Matchers with MockitoSugar {
private def generate(expr: ir.Expression): String = {
new ExpressionGenerator().expression(new GeneratorContext(), expr)
}

"literal" should {
"be generated" in {
generate(ir.Literal()) shouldBe "NULL"

generate(ir.Literal(binary = Some(Array(0x01, 0x02, 0x03)))) shouldBe "010203"

generate(ir.Literal(boolean = Some(true))) shouldBe "TRUE"

generate(ir.Literal(short = Some(123))) shouldBe "123"

generate(ir.Literal(integer = Some(123))) shouldBe "123"

generate(ir.Literal(long = Some(123))) shouldBe "123"

generate(ir.Literal(float = Some(123.4f))) shouldBe "123.4"

generate(ir.Literal(double = Some(123.4))) shouldBe "123.4"

generate(ir.Literal(string = Some("abc"))) shouldBe "\"abc\""

generate(ir.Literal(date = Some(1721757801000L))) shouldBe "\"2024-07-23\""

generate(ir.Literal(timestamp = Some(1721757801000L))) shouldBe "\"2024-07-23 18:03:21.000\""
}

"arrays" in {
generate(ir.Literal(Seq(ir.Literal("abc"), ir.Literal("def")))) shouldBe "ARRAY(\"abc\", \"def\")"
}

"maps" in {
generate(
ir.Literal(
Map(
"foo" -> ir.Literal("bar"),
"baz" -> ir.Literal("qux")))) shouldBe "MAP(\"foo\", \"bar\", \"baz\", \"qux\")"
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class DataTypeBuilderSpec extends AnyWordSpec with SnowflakeParserTestCommon wit
example("BOOLEAN", BooleanType)
example("DATE", DateType)
example("BINARY", BinaryType)
example("ARRAY", ArrayType())
example("ARRAY", ArrayType(UnresolvedType))
}

"translate the rest to UnparsedType" in {
Expand Down

0 comments on commit ec13a9e

Please sign in to comment.