Skip to content

Commit

Permalink
Fixed SELECT TOP X PERCENT IR translation for TSQL (#733)
Browse files Browse the repository at this point in the history
TSQL has `SELECT TOP N * FROM ..` vs `SELECT * FROM .. LIMIT N`, so we
fix it here
  • Loading branch information
nfx authored Jul 28, 2024
1 parent 5c98963 commit e17e64b
Show file tree
Hide file tree
Showing 18 changed files with 331 additions and 78 deletions.
6 changes: 6 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.lihaoyi</groupId>
<artifactId>pprint_${scala.binary.version}</artifactId>
<version>0.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.scala-logging</groupId>
<artifactId>scala-logging_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.databricks.labs.remorph.generators.sql
import com.databricks.labs.remorph.parsers.{intermediate => ir}
import com.databricks.labs.remorph.generators.{Generator, GeneratorContext}
import com.databricks.labs.remorph.parsers.intermediate.{ExceptSetOp, IntersectSetOp, UnionSetOp}
import com.databricks.labs.remorph.transpilers.TranspileException

class LogicalPlanGenerator(val explicitDistinct: Boolean = false) extends Generator[ir.LogicalPlan, String] {

Expand All @@ -18,10 +17,7 @@ class LogicalPlanGenerator(val explicitDistinct: Boolean = false) extends Genera
case ir.NamedTable(id, _, _) => id
case ir.Filter(input, condition) =>
s"${generate(ctx, input)} WHERE ${expr.generate(ctx, condition)}"
case ir.Limit(input, limit, percentage, _) =>
if (percentage) {
throw TranspileException("SELECT TOP .. PERCENT has to be transformed")
}
case ir.Limit(input, limit) =>
s"${generate(ctx, input)} LIMIT ${expr.generate(ctx, limit)}"
case join: ir.Join => generateJoin(ctx, join)
case setOp: ir.SetOperation => setOperation(ctx, setOp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ case class WindowFrame(frame_type: FrameType, lower: FrameBoundary, upper: Frame

case class Window(
window_function: Expression,
partition_spec: Seq[Expression],
sort_order: Seq[SortOrder],
frame_spec: Option[WindowFrame])
partition_spec: Seq[Expression] = Seq.empty,
sort_order: Seq[SortOrder] = Seq.empty,
frame_spec: Option[WindowFrame] = None)
extends Expression {
override def children: Seq[Expression] = window_function +: partition_spec
override def dataType: DataType = window_function.dataType
Expand All @@ -153,8 +153,11 @@ case object SortNullsUnspecified extends NullOrdering
case object SortNullsFirst extends NullOrdering
case object SortNullsLast extends NullOrdering

case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) extends Expression {
override def children: Seq[Expression] = child :: Nil
case class SortOrder(
child: Expression,
direction: SortDirection = AscendingSortDirection,
nullOrdering: NullOrdering = SortNullsUnspecified)
extends Unary(child) {
override def dataType: DataType = child.dataType
}

Expand Down Expand Up @@ -204,8 +207,8 @@ case class UpdateFields(struct_expression: Expression, field_name: String, value
override def dataType: DataType = UnresolvedType // TODO: Fix this
}

case class Alias(expr: Expression, name: Seq[Id], metadata: Option[String]) extends Expression {
override def children: Seq[Expression] = expr :: Nil
// TODO: has to be Alias(expr: Expression, name: String)
case class Alias(expr: Expression, name: Seq[Id], metadata: Option[String] = None) extends Unary(expr) {
override def dataType: DataType = expr.dataType
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,25 @@ case object NoopNode extends LeafNode {
override def output: Seq[Attribute] = Seq.empty
}

// TODO: (nfx) refactor to align more with catalyst
// TODO: (nfx) refactor to align more with catalyst, UnaryNode
// case class UnresolvedWith(child: LogicalPlan, ctes: Seq[(String, SubqueryAlias)])
case class WithCTE(ctes: Seq[LogicalPlan], query: LogicalPlan) extends RelationCommon {
override def output: Seq[Attribute] = query.output
override def children: Seq[LogicalPlan] = ctes :+ query
}

// TODO: (nfx) refactor to align more with catalyst
// TOTO: remove this class, replace with SubqueryAlias
case class CTEDefinition(tableName: String, columns: Seq[Expression], cte: LogicalPlan) extends RelationCommon {
override def output: Seq[Attribute] = columns.map(c => AttributeReference(c.toString, c.dataType))
override def children: Seq[LogicalPlan] = Seq(cte)
}

// TODO: (nfx) refactor to align more with catalyst
case class Star(objectName: Option[ObjectReference]) extends LeafExpression {
// TODO: (nfx) refactor to align more with catalyst, rename to UnresolvedStar
case class Star(objectName: Option[ObjectReference] = None) extends LeafExpression {
override def dataType: DataType = UnresolvedType
}

// TODO: (nfx) refactor to align more with catalyst
case class Exists(relation: LogicalPlan) extends ToRefactor

// Assignment operators
case class Assign(left: Expression, right: Expression) extends Binary(left, right) {
override def dataType: DataType = UnresolvedType
Expand Down Expand Up @@ -82,7 +81,7 @@ case class TableWithHints(child: LogicalPlan, hints: Seq[TableHint]) extends Una
}

case class Batch(children: Seq[LogicalPlan]) extends LogicalPlan {
override def output: Seq[Attribute] = Seq.empty
override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Seq())
}

case class FunctionParameter(name: String, dataType: DataType, defaultValue: Option[Expression])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ case class SQL(query: String, named_arguments: Map[String, Expression], pos_argu

abstract class Read(is_streaming: Boolean) extends LeafNode

// TODO: replace most (if not all) occurrences with UnresolvedRelation
case class NamedTable(unparsed_identifier: String, options: Map[String, String], is_streaming: Boolean)
extends Read(is_streaming) {
override def output: Seq[Attribute] = Seq.empty
Expand All @@ -29,6 +30,7 @@ case class DataSource(

case class Project(input: LogicalPlan, override val expressions: Seq[Expression]) extends UnaryNode {
override def child: LogicalPlan = input
// TODO: add resolver for Star
override def output: Seq[Attribute] = expressions.map(_.asInstanceOf[Attribute])
}

Expand Down Expand Up @@ -70,11 +72,8 @@ case class SetOperation(
override def output: Seq[Attribute] = left.output ++ right.output
}

// TODO: move is_percentage / with_ties to TSQL-specific nodes
case class Limit(input: LogicalPlan, limit: Expression, is_percentage: Boolean = false, with_ties: Boolean = false)
extends UnaryNode {
override def child: LogicalPlan = input
override def output: Seq[Attribute] = input.output
case class Limit(child: LogicalPlan, limit: Expression) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

case class Offset(child: LogicalPlan, offset: Expression) extends UnaryNode {
Expand Down Expand Up @@ -141,7 +140,8 @@ case class Range(start: Long, end: Long, step: Long, num_partitions: Int) extend
override def output: Seq[Attribute] = Seq(AttributeReference("id", LongType))
}

case class SubqueryAlias(child: LogicalPlan, alias: Id, columnNames: Seq[Id]) extends UnaryNode {
// TODO: most likely has to be SubqueryAlias(identifier: AliasIdentifier, child: LogicalPlan)
case class SubqueryAlias(child: LogicalPlan, alias: Id, columnNames: Seq[Id] = Seq.empty) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.databricks.labs.remorph.parsers.intermediate

abstract class Rule[T <: TreeNode[_]] {
val ruleName: String = {
val className = getClass.getName
if (className endsWith "$") className.dropRight(1) else className
}

def apply(plan: T): T
}

case class Rules[T <: TreeNode[_]](rules: Rule[T]*) extends Rule[T] {
def apply(plan: T): T = {
rules.foldLeft(plan) { case (p, rule) => rule(p) }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.databricks.labs.remorph.parsers.intermediate

abstract class SubqueryExpression(plan: LogicalPlan) extends Expression {
override def children: Seq[Expression] = plan.expressions // TODO: not sure if this is a good idea
}

// returns one column. TBD if we want to split between
// one row (scala) and ListQuery (many rows), as it makes
// little difference for SQL code generation.
// scalar: SELECT * FROM a WHERE id = (SELECT id FROM b LIMIT 1)
// list: SELECT * FROM a WHERE id IN(SELECT id FROM b)
case class ScalarSubquery(relation: LogicalPlan) extends SubqueryExpression(relation) {
// TODO: we need to resolve schema of the plan
// before we get the type of this expression
override def dataType: DataType = UnresolvedType
}

// checks if a row exists in a subquery given some condition
case class Exists(relation: LogicalPlan) extends SubqueryExpression(relation) {
// TODO: we need to resolve schema of the plan
// before we get the type of this expression
override def dataType: DataType = UnresolvedType
}
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
*/
def simpleString(maxFields: Int): String = s"$nodeName ${argString(maxFields)}".trim

override def toString: String = treeString
override def toString: String = treeString.replaceAll("\n]\n", "]\n") // TODO: fix properly

/** Returns a string representation of the nodes in this tree */
final def treeString: String = treeString()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.intermediate.ScalarSubquery
import com.databricks.labs.remorph.parsers.tsql.TSqlParser._
import com.databricks.labs.remorph.parsers.{ParserCommon, XmlFunction, tsql, intermediate => ir}
import org.antlr.v4.runtime.Token
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.tsql.TSqlParser._
import com.databricks.labs.remorph.parsers.tsql.rules.TopPercent
import com.databricks.labs.remorph.parsers.{intermediate => ir}
import org.antlr.v4.runtime.ParserRuleContext

Expand Down Expand Up @@ -60,11 +61,12 @@ class TSqlRelationBuilder extends TSqlParserBaseVisitor[ir.LogicalPlan] {

private def buildTop(ctxOpt: Option[TSqlParser.TopClauseContext], input: ir.LogicalPlan): ir.LogicalPlan =
ctxOpt.fold(input) { top =>
ir.Limit(
input,
top.expression().accept(expressionBuilder),
is_percentage = top.PERCENT() != null,
with_ties = top.TIES() != null)
val limit = top.expression().accept(expressionBuilder)
if (top.PERCENT() != null) {
TopPercent(input, limit, with_ties = top.TIES() != null)
} else {
ir.Limit(input, limit)
}
}

override def visitSelectOptionalClauses(ctx: SelectOptionalClausesContext): ir.LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ case class Inserted(selection: Expression) extends Unary(selection) {
override def dataType: DataType = selection.dataType
}

case class ScalarSubquery(relation: LogicalPlan) extends ToRefactor

// The default case for the expression parser needs to be explicitly defined to distinguish [DEFAULT]
case class Default() extends LeafExpression {
override def dataType: DataType = UnresolvedType
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.databricks.labs.remorph.parsers.tsql.rules

import com.databricks.labs.remorph.parsers.intermediate._

// TSQL has "SELECT TOP N * FROM .." vs "SELECT * FROM .. LIMIT N", so we fix it here
object PullLimitUpwards extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Project(Limit(child, limit), exprs) =>
Limit(Project(child, exprs), limit)
case Filter(Limit(child, limit), cond) =>
Limit(Filter(child, cond), limit)
case Sort(Limit(child, limit), order, global) =>
Limit(Sort(child, order, global), limit)
case Offset(Limit(child, limit), offset) =>
Limit(Offset(child, offset), limit)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.databricks.labs.remorph.parsers.tsql.rules

import com.databricks.labs.remorph.parsers.intermediate._

import java.util.concurrent.atomic.AtomicLong

case class TopPercent(child: LogicalPlan, percentage: Expression, with_ties: Boolean = false) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

class TopPercentToLimitSubquery extends Rule[LogicalPlan] {
private val counter = new AtomicLong()
override def apply(plan: LogicalPlan): LogicalPlan = normalize(plan) transformUp {
case TopPercent(child, percentage, withTies) =>
if (withTies) {
withPercentiles(child, percentage)
} else {
viaTotalCount(child, percentage)
}
}

/** See [[PullLimitUpwards]] */
private def normalize(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Project(TopPercent(child, limit, withTies), exprs) =>
TopPercent(Project(child, exprs), limit, withTies)
case Filter(TopPercent(child, limit, withTies), cond) =>
TopPercent(Filter(child, cond), limit, withTies)
case Sort(TopPercent(child, limit, withTies), order, global) =>
TopPercent(Sort(child, order, global), limit, withTies)
case Offset(TopPercent(child, limit, withTies), offset) =>
TopPercent(Offset(child, offset), limit, withTies)
}

private def withPercentiles(child: LogicalPlan, percentage: Expression) = {
val cteSuffix = counter.incrementAndGet()
val originalCteName = s"_limited$cteSuffix"
val withPercentileCteName = s"_with_percentile$cteSuffix"
val percentileColName = s"_percentile$cteSuffix"
child match {
case Sort(child, order, _) =>
// this is (temporary) hack due to the lack of star resolution. otherwise child.output is fine
val reProject = child.find(_.isInstanceOf[Project]).map(_.asInstanceOf[Project]) match {
case Some(Project(_, expressions)) => expressions
case None =>
throw new IllegalArgumentException("Cannot find a projection")
}
WithCTE(
Seq(
SubqueryAlias(child, Id(originalCteName)),
SubqueryAlias(
Project(
UnresolvedRelation(originalCteName),
reProject ++ Seq(
Alias(Window(NTile(Literal(short = Some(100))), sort_order = order), Seq(Id(percentileColName))))),
Id(withPercentileCteName))),
Filter(
Project(UnresolvedRelation(withPercentileCteName), reProject),
LessThanOrEqual(UnresolvedAttribute(percentileColName), Divide(percentage, Literal(short = Some(100))))))
case _ =>
// TODO: (jimidle) figure out cases when this is not true
throw new IllegalArgumentException("TopPercent with ties requires a Sort node")
}
}

private def viaTotalCount(child: LogicalPlan, percentage: Expression) = {
val cteSuffix = counter.incrementAndGet()
val originalCteName = s"_limited$cteSuffix"
val countedCteName = s"_counted$cteSuffix"
WithCTE(
Seq(
SubqueryAlias(child, Id(originalCteName)),
SubqueryAlias(
Project(UnresolvedRelation(originalCteName), Seq(Alias(Count(Seq(Star())), Seq(Id("count"))))),
Id(countedCteName))),
Limit(
Project(UnresolvedRelation(originalCteName), Seq(Star())),
ScalarSubquery(
Project(
UnresolvedRelation(countedCteName),
Seq(Cast(Multiply(Divide(Id("count"), percentage), Literal(short = Some(100))), LongType))))))
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package com.databricks.labs.remorph.parsers

import com.databricks.labs.remorph.parsers.{intermediate => ir}
import com.databricks.labs.remorph.utils.Strings
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.tree.ParseTreeVisitor
import org.scalatest.{Assertion, Assertions}

trait ParserTestCommon[P <: Parser] { self: Assertions =>
trait ParserTestCommon[P <: Parser] extends PlanComparison { self: Assertions =>

protected def makeLexer(chars: CharStream): TokenSource
protected def makeParser(tokens: TokenStream): P
Expand All @@ -29,27 +28,6 @@ trait ParserTestCommon[P <: Parser] { self: Assertions =>
tree
}

protected def comparePlans(a: ir.LogicalPlan, b: ir.LogicalPlan): Unit = {
val expected = reorderComparisons(a)
val actual = reorderComparisons(b)
if (expected != actual) {
fail(s"""
|== FAIL: Plans do not match ===
|${Strings.sideBySide(expected.treeString, actual.treeString).mkString("\n")}
""".stripMargin)
}
}

protected def reorderComparisons(plan: ir.LogicalPlan): ir.LogicalPlan = {
plan transformAllExpressions {
case ir.Equals(l, r) if l.hashCode() > r.hashCode() => ir.Equals(r, l)
case ir.GreaterThan(l, r) if l.hashCode() > r.hashCode() => ir.LessThan(r, l)
case ir.GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => ir.LessThanOrEqual(r, l)
case ir.LessThan(l, r) if l.hashCode() > r.hashCode() => ir.GreaterThan(r, l)
case ir.LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => ir.GreaterThanOrEqual(r, l)
}
}

protected def example[R <: RuleContext](query: String, rule: P => R, expectedAst: ir.LogicalPlan) = {
val sfTree = parseString(query, rule)
if (errHandler != null && errHandler.errorCount != 0) {
Expand Down
Loading

0 comments on commit e17e64b

Please sign in to comment.