Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed SELECT TOP X PERCENT IR translation for TSQL #733

Merged
merged 4 commits into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,12 @@
<version>${mockito.version}</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.lihaoyi</groupId>
<artifactId>pprint_${scala.binary.version}</artifactId>
<version>0.8.1</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>com.typesafe.scala-logging</groupId>
<artifactId>scala-logging_${scala.binary.version}</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package com.databricks.labs.remorph.generators.sql
import com.databricks.labs.remorph.parsers.{intermediate => ir}
import com.databricks.labs.remorph.generators.{Generator, GeneratorContext}
import com.databricks.labs.remorph.parsers.intermediate.{ExceptSetOp, IntersectSetOp, UnionSetOp}
import com.databricks.labs.remorph.transpilers.TranspileException

class LogicalPlanGenerator(val explicitDistinct: Boolean = false) extends Generator[ir.LogicalPlan, String] {

Expand All @@ -18,10 +17,7 @@ class LogicalPlanGenerator(val explicitDistinct: Boolean = false) extends Genera
case ir.NamedTable(id, _, _) => id
case ir.Filter(input, condition) =>
s"${generate(ctx, input)} WHERE ${expr.generate(ctx, condition)}"
case ir.Limit(input, limit, percentage, _) =>
if (percentage) {
throw TranspileException("SELECT TOP .. PERCENT has to be transformed")
}
case ir.Limit(input, limit) =>
s"${generate(ctx, input)} LIMIT ${expr.generate(ctx, limit)}"
case join: ir.Join => generateJoin(ctx, join)
case setOp: ir.SetOperation => setOperation(ctx, setOp)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,9 @@ case class WindowFrame(frame_type: FrameType, lower: FrameBoundary, upper: Frame

case class Window(
window_function: Expression,
partition_spec: Seq[Expression],
sort_order: Seq[SortOrder],
frame_spec: Option[WindowFrame])
partition_spec: Seq[Expression] = Seq.empty,
sort_order: Seq[SortOrder] = Seq.empty,
frame_spec: Option[WindowFrame] = None)
extends Expression {
override def children: Seq[Expression] = window_function +: partition_spec
override def dataType: DataType = window_function.dataType
Expand All @@ -153,8 +153,11 @@ case object SortNullsUnspecified extends NullOrdering
case object SortNullsFirst extends NullOrdering
case object SortNullsLast extends NullOrdering

case class SortOrder(child: Expression, direction: SortDirection, nullOrdering: NullOrdering) extends Expression {
override def children: Seq[Expression] = child :: Nil
case class SortOrder(
child: Expression,
direction: SortDirection = AscendingSortDirection,
nullOrdering: NullOrdering = SortNullsUnspecified)
extends Unary(child) {
override def dataType: DataType = child.dataType
}

Expand Down Expand Up @@ -204,8 +207,8 @@ case class UpdateFields(struct_expression: Expression, field_name: String, value
override def dataType: DataType = UnresolvedType // TODO: Fix this
}

case class Alias(expr: Expression, name: Seq[Id], metadata: Option[String]) extends Expression {
override def children: Seq[Expression] = expr :: Nil
// TODO: has to be Alias(expr: Expression, name: String)
case class Alias(expr: Expression, name: Seq[Id], metadata: Option[String] = None) extends Unary(expr) {
override def dataType: DataType = expr.dataType
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,26 +26,25 @@ case object NoopNode extends LeafNode {
override def output: Seq[Attribute] = Seq.empty
}

// TODO: (nfx) refactor to align more with catalyst
// TODO: (nfx) refactor to align more with catalyst, UnaryNode
// case class UnresolvedWith(child: LogicalPlan, ctes: Seq[(String, SubqueryAlias)])
case class WithCTE(ctes: Seq[LogicalPlan], query: LogicalPlan) extends RelationCommon {
override def output: Seq[Attribute] = query.output
override def children: Seq[LogicalPlan] = ctes :+ query
}

// TODO: (nfx) refactor to align more with catalyst
// TOTO: remove this class, replace with SubqueryAlias
case class CTEDefinition(tableName: String, columns: Seq[Expression], cte: LogicalPlan) extends RelationCommon {
override def output: Seq[Attribute] = columns.map(c => AttributeReference(c.toString, c.dataType))
override def children: Seq[LogicalPlan] = Seq(cte)
}

// TODO: (nfx) refactor to align more with catalyst
case class Star(objectName: Option[ObjectReference]) extends LeafExpression {
// TODO: (nfx) refactor to align more with catalyst, rename to UnresolvedStar
case class Star(objectName: Option[ObjectReference] = None) extends LeafExpression {
override def dataType: DataType = UnresolvedType
}

// TODO: (nfx) refactor to align more with catalyst
case class Exists(relation: LogicalPlan) extends ToRefactor

// Assignment operators
case class Assign(left: Expression, right: Expression) extends Binary(left, right) {
override def dataType: DataType = UnresolvedType
Expand Down Expand Up @@ -82,7 +81,7 @@ case class TableWithHints(child: LogicalPlan, hints: Seq[TableHint]) extends Una
}

case class Batch(children: Seq[LogicalPlan]) extends LogicalPlan {
override def output: Seq[Attribute] = Seq.empty
override def output: Seq[Attribute] = children.lastOption.map(_.output).getOrElse(Seq())
}

case class FunctionParameter(name: String, dataType: DataType, defaultValue: Option[Expression])
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ case class SQL(query: String, named_arguments: Map[String, Expression], pos_argu

abstract class Read(is_streaming: Boolean) extends LeafNode

// TODO: replace most (if not all) occurrences with UnresolvedRelation
case class NamedTable(unparsed_identifier: String, options: Map[String, String], is_streaming: Boolean)
extends Read(is_streaming) {
override def output: Seq[Attribute] = Seq.empty
Expand All @@ -29,6 +30,7 @@ case class DataSource(

case class Project(input: LogicalPlan, override val expressions: Seq[Expression]) extends UnaryNode {
override def child: LogicalPlan = input
// TODO: add resolver for Star
override def output: Seq[Attribute] = expressions.map(_.asInstanceOf[Attribute])
}

Expand Down Expand Up @@ -70,11 +72,8 @@ case class SetOperation(
override def output: Seq[Attribute] = left.output ++ right.output
}

// TODO: move is_percentage / with_ties to TSQL-specific nodes
case class Limit(input: LogicalPlan, limit: Expression, is_percentage: Boolean = false, with_ties: Boolean = false)
extends UnaryNode {
override def child: LogicalPlan = input
override def output: Seq[Attribute] = input.output
case class Limit(child: LogicalPlan, limit: Expression) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

case class Offset(child: LogicalPlan, offset: Expression) extends UnaryNode {
Expand Down Expand Up @@ -141,7 +140,8 @@ case class Range(start: Long, end: Long, step: Long, num_partitions: Int) extend
override def output: Seq[Attribute] = Seq(AttributeReference("id", LongType))
}

case class SubqueryAlias(child: LogicalPlan, alias: Id, columnNames: Seq[Id]) extends UnaryNode {
// TODO: most likely has to be SubqueryAlias(identifier: AliasIdentifier, child: LogicalPlan)
case class SubqueryAlias(child: LogicalPlan, alias: Id, columnNames: Seq[Id] = Seq.empty) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
package com.databricks.labs.remorph.parsers.intermediate

abstract class Rule[T <: TreeNode[_]] {
val ruleName: String = {
val className = getClass.getName
if (className endsWith "$") className.dropRight(1) else className
}

def apply(plan: T): T
}

case class Rules[T <: TreeNode[_]](rules: Rule[T]*) extends Rule[T] {
def apply(plan: T): T = {
rules.foldLeft(plan) { case (p, rule) => rule(p) }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package com.databricks.labs.remorph.parsers.intermediate

abstract class SubqueryExpression(plan: LogicalPlan) extends Expression {
override def children: Seq[Expression] = plan.expressions // TODO: not sure if this is a good idea
}

// returns one column. TBD if we want to split between
// one row (scala) and ListQuery (many rows), as it makes
// little difference for SQL code generation.
// scalar: SELECT * FROM a WHERE id = (SELECT id FROM b LIMIT 1)
// list: SELECT * FROM a WHERE id IN(SELECT id FROM b)
case class ScalarSubquery(relation: LogicalPlan) extends SubqueryExpression(relation) {
// TODO: we need to resolve schema of the plan
// before we get the type of this expression
override def dataType: DataType = UnresolvedType
}

// checks if a row exists in a subquery given some condition
case class Exists(relation: LogicalPlan) extends SubqueryExpression(relation) {
// TODO: we need to resolve schema of the plan
// before we get the type of this expression
override def dataType: DataType = UnresolvedType
}
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
*/
def simpleString(maxFields: Int): String = s"$nodeName ${argString(maxFields)}".trim

override def toString: String = treeString
override def toString: String = treeString.replaceAll("\n]\n", "]\n") // TODO: fix properly

/** Returns a string representation of the nodes in this tree */
final def treeString: String = treeString()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.intermediate.ScalarSubquery
import com.databricks.labs.remorph.parsers.tsql.TSqlParser._
import com.databricks.labs.remorph.parsers.{ParserCommon, XmlFunction, tsql, intermediate => ir}
import org.antlr.v4.runtime.Token
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.databricks.labs.remorph.parsers.tsql

import com.databricks.labs.remorph.parsers.tsql.TSqlParser._
import com.databricks.labs.remorph.parsers.tsql.rules.TopPercent
import com.databricks.labs.remorph.parsers.{intermediate => ir}
import org.antlr.v4.runtime.ParserRuleContext

Expand Down Expand Up @@ -60,11 +61,12 @@ class TSqlRelationBuilder extends TSqlParserBaseVisitor[ir.LogicalPlan] {

private def buildTop(ctxOpt: Option[TSqlParser.TopClauseContext], input: ir.LogicalPlan): ir.LogicalPlan =
ctxOpt.fold(input) { top =>
ir.Limit(
input,
top.expression().accept(expressionBuilder),
is_percentage = top.PERCENT() != null,
with_ties = top.TIES() != null)
val limit = top.expression().accept(expressionBuilder)
if (top.PERCENT() != null) {
TopPercent(input, limit, with_ties = top.TIES() != null)
} else {
ir.Limit(input, limit)
}
}

override def visitSelectOptionalClauses(ctx: SelectOptionalClausesContext): ir.LogicalPlan = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ case class Inserted(selection: Expression) extends Unary(selection) {
override def dataType: DataType = selection.dataType
}

case class ScalarSubquery(relation: LogicalPlan) extends ToRefactor

// The default case for the expression parser needs to be explicitly defined to distinguish [DEFAULT]
case class Default() extends LeafExpression {
override def dataType: DataType = UnresolvedType
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package com.databricks.labs.remorph.parsers.tsql.rules

import com.databricks.labs.remorph.parsers.intermediate._

// TSQL has "SELECT TOP N * FROM .." vs "SELECT * FROM .. LIMIT N", so we fix it here
object PullLimitUpwards extends Rule[LogicalPlan] {
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Project(Limit(child, limit), exprs) =>
Limit(Project(child, exprs), limit)
case Filter(Limit(child, limit), cond) =>
Limit(Filter(child, cond), limit)
case Sort(Limit(child, limit), order, global) =>
Limit(Sort(child, order, global), limit)
case Offset(Limit(child, limit), offset) =>
Limit(Offset(child, offset), limit)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package com.databricks.labs.remorph.parsers.tsql.rules

import com.databricks.labs.remorph.parsers.intermediate._

import java.util.concurrent.atomic.AtomicLong

case class TopPercent(child: LogicalPlan, percentage: Expression, with_ties: Boolean = false) extends UnaryNode {
override def output: Seq[Attribute] = child.output
}

class TopPercentToLimitSubquery extends Rule[LogicalPlan] {
private val counter = new AtomicLong()
override def apply(plan: LogicalPlan): LogicalPlan = normalize(plan) transformUp {
case TopPercent(child, percentage, withTies) =>
if (withTies) {
withPercentiles(child, percentage)
} else {
viaTotalCount(child, percentage)
}
}

/** See [[PullLimitUpwards]] */
private def normalize(plan: LogicalPlan): LogicalPlan = plan transformUp {
case Project(TopPercent(child, limit, withTies), exprs) =>
TopPercent(Project(child, exprs), limit, withTies)
case Filter(TopPercent(child, limit, withTies), cond) =>
TopPercent(Filter(child, cond), limit, withTies)
case Sort(TopPercent(child, limit, withTies), order, global) =>
TopPercent(Sort(child, order, global), limit, withTies)
case Offset(TopPercent(child, limit, withTies), offset) =>
TopPercent(Offset(child, offset), limit, withTies)
}

private def withPercentiles(child: LogicalPlan, percentage: Expression) = {
val cteSuffix = counter.incrementAndGet()
val originalCteName = s"_limited$cteSuffix"
val withPercentileCteName = s"_with_percentile$cteSuffix"
val percentileColName = s"_percentile$cteSuffix"
child match {
case Sort(child, order, _) =>
// this is (temporary) hack due to the lack of star resolution. otherwise child.output is fine
val reProject = child.find(_.isInstanceOf[Project]).map(_.asInstanceOf[Project]) match {
case Some(Project(_, expressions)) => expressions
case None =>
throw new IllegalArgumentException("Cannot find a projection")
}
WithCTE(
Seq(
SubqueryAlias(child, Id(originalCteName)),
SubqueryAlias(
Project(
UnresolvedRelation(originalCteName),
reProject ++ Seq(
Alias(Window(NTile(Literal(short = Some(100))), sort_order = order), Seq(Id(percentileColName))))),
Id(withPercentileCteName))),
Filter(
Project(UnresolvedRelation(withPercentileCteName), reProject),
LessThanOrEqual(UnresolvedAttribute(percentileColName), Divide(percentage, Literal(short = Some(100))))))
case _ =>
// TODO: (jimidle) figure out cases when this is not true
throw new IllegalArgumentException("TopPercent with ties requires a Sort node")
}
}

private def viaTotalCount(child: LogicalPlan, percentage: Expression) = {
val cteSuffix = counter.incrementAndGet()
val originalCteName = s"_limited$cteSuffix"
val countedCteName = s"_counted$cteSuffix"
WithCTE(
Seq(
SubqueryAlias(child, Id(originalCteName)),
SubqueryAlias(
Project(UnresolvedRelation(originalCteName), Seq(Alias(Count(Seq(Star())), Seq(Id("count"))))),
Id(countedCteName))),
Limit(
Project(UnresolvedRelation(originalCteName), Seq(Star())),
ScalarSubquery(
Project(
UnresolvedRelation(countedCteName),
Seq(Cast(Multiply(Divide(Id("count"), percentage), Literal(short = Some(100))), LongType))))))
}
}
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
package com.databricks.labs.remorph.parsers

import com.databricks.labs.remorph.parsers.{intermediate => ir}
import com.databricks.labs.remorph.utils.Strings
import org.antlr.v4.runtime._
import org.antlr.v4.runtime.tree.ParseTreeVisitor
import org.scalatest.{Assertion, Assertions}

trait ParserTestCommon[P <: Parser] { self: Assertions =>
trait ParserTestCommon[P <: Parser] extends PlanComparison { self: Assertions =>

protected def makeLexer(chars: CharStream): TokenSource
protected def makeParser(tokens: TokenStream): P
Expand All @@ -29,27 +28,6 @@ trait ParserTestCommon[P <: Parser] { self: Assertions =>
tree
}

protected def comparePlans(a: ir.LogicalPlan, b: ir.LogicalPlan): Unit = {
val expected = reorderComparisons(a)
val actual = reorderComparisons(b)
if (expected != actual) {
fail(s"""
|== FAIL: Plans do not match ===
|${Strings.sideBySide(expected.treeString, actual.treeString).mkString("\n")}
""".stripMargin)
}
}

protected def reorderComparisons(plan: ir.LogicalPlan): ir.LogicalPlan = {
plan transformAllExpressions {
case ir.Equals(l, r) if l.hashCode() > r.hashCode() => ir.Equals(r, l)
case ir.GreaterThan(l, r) if l.hashCode() > r.hashCode() => ir.LessThan(r, l)
case ir.GreaterThanOrEqual(l, r) if l.hashCode() > r.hashCode() => ir.LessThanOrEqual(r, l)
case ir.LessThan(l, r) if l.hashCode() > r.hashCode() => ir.GreaterThan(r, l)
case ir.LessThanOrEqual(l, r) if l.hashCode() > r.hashCode() => ir.GreaterThanOrEqual(r, l)
}
}

protected def example[R <: RuleContext](query: String, rule: P => R, expectedAst: ir.LogicalPlan) = {
val sfTree = parseString(query, rule)
if (errHandler != null && errHandler.errorCount != 0) {
Expand Down
Loading
Loading