-
Notifications
You must be signed in to change notification settings - Fork 36
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Fixed
SELECT TOP X PERCENT
IR translation for TSQL (#733)
TSQL has `SELECT TOP N * FROM ..` vs `SELECT * FROM .. LIMIT N`, so we fix it here
- Loading branch information
Showing
18 changed files
with
331 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
16 changes: 16 additions & 0 deletions
16
core/src/main/scala/com/databricks/labs/remorph/parsers/intermediate/rules.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
package com.databricks.labs.remorph.parsers.intermediate | ||
|
||
abstract class Rule[T <: TreeNode[_]] { | ||
val ruleName: String = { | ||
val className = getClass.getName | ||
if (className endsWith "$") className.dropRight(1) else className | ||
} | ||
|
||
def apply(plan: T): T | ||
} | ||
|
||
case class Rules[T <: TreeNode[_]](rules: Rule[T]*) extends Rule[T] { | ||
def apply(plan: T): T = { | ||
rules.foldLeft(plan) { case (p, rule) => rule(p) } | ||
} | ||
} |
23 changes: 23 additions & 0 deletions
23
core/src/main/scala/com/databricks/labs/remorph/parsers/intermediate/subqueries.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
package com.databricks.labs.remorph.parsers.intermediate | ||
|
||
abstract class SubqueryExpression(plan: LogicalPlan) extends Expression { | ||
override def children: Seq[Expression] = plan.expressions // TODO: not sure if this is a good idea | ||
} | ||
|
||
// returns one column. TBD if we want to split between | ||
// one row (scala) and ListQuery (many rows), as it makes | ||
// little difference for SQL code generation. | ||
// scalar: SELECT * FROM a WHERE id = (SELECT id FROM b LIMIT 1) | ||
// list: SELECT * FROM a WHERE id IN(SELECT id FROM b) | ||
case class ScalarSubquery(relation: LogicalPlan) extends SubqueryExpression(relation) { | ||
// TODO: we need to resolve schema of the plan | ||
// before we get the type of this expression | ||
override def dataType: DataType = UnresolvedType | ||
} | ||
|
||
// checks if a row exists in a subquery given some condition | ||
case class Exists(relation: LogicalPlan) extends SubqueryExpression(relation) { | ||
// TODO: we need to resolve schema of the plan | ||
// before we get the type of this expression | ||
override def dataType: DataType = UnresolvedType | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
1 change: 1 addition & 0 deletions
1
core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/TSqlExpressionBuilder.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
17 changes: 17 additions & 0 deletions
17
core/src/main/scala/com/databricks/labs/remorph/parsers/tsql/rules/PullLimitUpwards.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package com.databricks.labs.remorph.parsers.tsql.rules | ||
|
||
import com.databricks.labs.remorph.parsers.intermediate._ | ||
|
||
// TSQL has "SELECT TOP N * FROM .." vs "SELECT * FROM .. LIMIT N", so we fix it here | ||
object PullLimitUpwards extends Rule[LogicalPlan] { | ||
override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { | ||
case Project(Limit(child, limit), exprs) => | ||
Limit(Project(child, exprs), limit) | ||
case Filter(Limit(child, limit), cond) => | ||
Limit(Filter(child, cond), limit) | ||
case Sort(Limit(child, limit), order, global) => | ||
Limit(Sort(child, order, global), limit) | ||
case Offset(Limit(child, limit), offset) => | ||
Limit(Offset(child, offset), limit) | ||
} | ||
} |
82 changes: 82 additions & 0 deletions
82
...main/scala/com/databricks/labs/remorph/parsers/tsql/rules/TopPercentToLimitSubquery.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
package com.databricks.labs.remorph.parsers.tsql.rules | ||
|
||
import com.databricks.labs.remorph.parsers.intermediate._ | ||
|
||
import java.util.concurrent.atomic.AtomicLong | ||
|
||
case class TopPercent(child: LogicalPlan, percentage: Expression, with_ties: Boolean = false) extends UnaryNode { | ||
override def output: Seq[Attribute] = child.output | ||
} | ||
|
||
class TopPercentToLimitSubquery extends Rule[LogicalPlan] { | ||
private val counter = new AtomicLong() | ||
override def apply(plan: LogicalPlan): LogicalPlan = normalize(plan) transformUp { | ||
case TopPercent(child, percentage, withTies) => | ||
if (withTies) { | ||
withPercentiles(child, percentage) | ||
} else { | ||
viaTotalCount(child, percentage) | ||
} | ||
} | ||
|
||
/** See [[PullLimitUpwards]] */ | ||
private def normalize(plan: LogicalPlan): LogicalPlan = plan transformUp { | ||
case Project(TopPercent(child, limit, withTies), exprs) => | ||
TopPercent(Project(child, exprs), limit, withTies) | ||
case Filter(TopPercent(child, limit, withTies), cond) => | ||
TopPercent(Filter(child, cond), limit, withTies) | ||
case Sort(TopPercent(child, limit, withTies), order, global) => | ||
TopPercent(Sort(child, order, global), limit, withTies) | ||
case Offset(TopPercent(child, limit, withTies), offset) => | ||
TopPercent(Offset(child, offset), limit, withTies) | ||
} | ||
|
||
private def withPercentiles(child: LogicalPlan, percentage: Expression) = { | ||
val cteSuffix = counter.incrementAndGet() | ||
val originalCteName = s"_limited$cteSuffix" | ||
val withPercentileCteName = s"_with_percentile$cteSuffix" | ||
val percentileColName = s"_percentile$cteSuffix" | ||
child match { | ||
case Sort(child, order, _) => | ||
// this is (temporary) hack due to the lack of star resolution. otherwise child.output is fine | ||
val reProject = child.find(_.isInstanceOf[Project]).map(_.asInstanceOf[Project]) match { | ||
case Some(Project(_, expressions)) => expressions | ||
case None => | ||
throw new IllegalArgumentException("Cannot find a projection") | ||
} | ||
WithCTE( | ||
Seq( | ||
SubqueryAlias(child, Id(originalCteName)), | ||
SubqueryAlias( | ||
Project( | ||
UnresolvedRelation(originalCteName), | ||
reProject ++ Seq( | ||
Alias(Window(NTile(Literal(short = Some(100))), sort_order = order), Seq(Id(percentileColName))))), | ||
Id(withPercentileCteName))), | ||
Filter( | ||
Project(UnresolvedRelation(withPercentileCteName), reProject), | ||
LessThanOrEqual(UnresolvedAttribute(percentileColName), Divide(percentage, Literal(short = Some(100)))))) | ||
case _ => | ||
// TODO: (jimidle) figure out cases when this is not true | ||
throw new IllegalArgumentException("TopPercent with ties requires a Sort node") | ||
} | ||
} | ||
|
||
private def viaTotalCount(child: LogicalPlan, percentage: Expression) = { | ||
val cteSuffix = counter.incrementAndGet() | ||
val originalCteName = s"_limited$cteSuffix" | ||
val countedCteName = s"_counted$cteSuffix" | ||
WithCTE( | ||
Seq( | ||
SubqueryAlias(child, Id(originalCteName)), | ||
SubqueryAlias( | ||
Project(UnresolvedRelation(originalCteName), Seq(Alias(Count(Seq(Star())), Seq(Id("count"))))), | ||
Id(countedCteName))), | ||
Limit( | ||
Project(UnresolvedRelation(originalCteName), Seq(Star())), | ||
ScalarSubquery( | ||
Project( | ||
UnresolvedRelation(countedCteName), | ||
Seq(Cast(Multiply(Divide(Id("count"), percentage), Literal(short = Some(100))), LongType)))))) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.