Skip to content

Commit

Permalink
Add unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
vil1 committed Aug 5, 2024
1 parent d62df32 commit 962db2a
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.databricks.labs.remorph.generators.sql

import com.databricks.labs.remorph.generators.{Generator, GeneratorContext}
import com.databricks.labs.remorph.parsers.intermediate.{FrameBoundary, RLike}
import com.databricks.labs.remorph.parsers.intermediate.RLike
import com.databricks.labs.remorph.parsers.{intermediate => ir}
import com.databricks.labs.remorph.transpilers.TranspileException

Expand Down Expand Up @@ -41,6 +41,7 @@ class ExpressionGenerator(val callMapper: ir.CallMapper = new ir.CallMapper())
case opts: ir.Options => options(ctx, opts)
case i: ir.KnownInterval => interval(ctx, i)
case w: ir.Window => window(ctx, w)
case o: ir.SortOrder => sortOrder(ctx, o)
case x => throw TranspileException(s"Unsupported expression: $x")
}
}
Expand Down Expand Up @@ -277,27 +278,44 @@ class ExpressionGenerator(val callMapper: ir.CallMapper = new ir.CallMapper())
val partition = if (window.partition_spec.isEmpty) { "" }
else { window.partition_spec.map(expression(ctx, _)).mkString("PARTITION BY ", ", ", "") }
val orderBy = if (window.sort_order.isEmpty) { "" }
else { window.sort_order.map(expression(ctx, _)).mkString(" ORDER BY ", ", ", "") }
else { window.sort_order.map(sortOrder(ctx, _)).mkString(" ORDER BY ", ", ", "") }
val windowFrame = window.frame_spec
.map { frame =>
val mode = frame.frame_type match {
case ir.RowsFrame => "ROWS"
case ir.RangeFrame => "RANGE"
}
val boundaries = (frameBoundary(frame.lower) ++ frameBoundary(frame.upper)).mkString(" AND ")
s" $mode $boundaries"
val boundaries = frameBoundary(ctx, frame.lower) ++ frameBoundary(ctx, frame.upper)
val frameBoundaries = if (boundaries.size < 2) { boundaries.mkString }
else { boundaries.mkString("BETWEEN ", " AND ", "") }
s" $mode $frameBoundaries"
}
.getOrElse("")
s"$expr OVER ($partition$orderBy$windowFrame)"
}

private def frameBoundary(boundary: FrameBoundary): Seq[String] = boundary match {
private def frameBoundary(ctx: GeneratorContext, boundary: ir.FrameBoundary): Seq[String] = boundary match {
case ir.NoBoundary => Seq.empty
case ir.CurrentRow => Seq("CURRENT ROW")
case ir.UnboundedPreceding => Seq("UNBOUNDED PRECEDING")
case ir.UnboundedFollowing => Seq("UNBOUNDED FOLLOWING")
case ir.PrecedingN(n) => Seq(s"$n PRECEDING")
case ir.FollowingN(n) => Seq(s"$n FOLLOWING")
case ir.PrecedingN(n) => Seq(s"${expression(ctx, n)} PRECEDING")
case ir.FollowingN(n) => Seq(s"${expression(ctx, n)} FOLLOWING")
}

private def sortOrder(ctx: GeneratorContext, order: ir.SortOrder): String = {
val orderBy = expression(ctx, order.child)
val direction = order.direction match {
case ir.Ascending => Seq("ASC")
case ir.Descending => Seq("DESC")
case ir.UnspecifiedSortDirection => Seq()
}
val nulls = order.nullOrdering match {
case ir.NullsFirst => Seq("NULLS FIRST")
case ir.NullsLast => Seq("NULLS LAST")
case ir.SortNullsUnspecified => Seq()
}
(Seq(orderBy) ++ direction ++ nulls).mkString(" ")
}

private def orNull(option: Option[String]): String = option.getOrElse("NULL")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -860,4 +860,30 @@ class ExpressionGeneratorTest extends AnyWordSpec with GeneratorTestCommon[ir.Ex
ir.Id("table 1", caseSensitive = true)))) generates "schema1.\"table 1\".*"
}
}

"window functions" should {
"be generated" in {
ir.Window(
ir.RowNumber(),
Seq(ir.Id("a")),
Seq(ir.SortOrder(ir.Id("b"), ir.Ascending, ir.NullsFirst)),
Some(
ir.WindowFrame(
ir.RowsFrame,
ir.CurrentRow,
ir.NoBoundary))) generates "ROW_NUMBER() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST ROWS CURRENT ROW)"

ir.Window(
ir.RowNumber(),
Seq(ir.Id("a")),
Seq(ir.SortOrder(ir.Id("b"), ir.Ascending, ir.NullsFirst)),
Some(
ir.WindowFrame(
ir.RangeFrame,
ir.CurrentRow,
ir.FollowingN(
ir.Literal(42))))) generates "ROW_NUMBER() OVER (PARTITION BY a ORDER BY b ASC NULLS FIRST RANGE BETWEEN CURRENT ROW AND 42 FOLLOWING)"

}
}
}

0 comments on commit 962db2a

Please sign in to comment.