Skip to content

Commit

Permalink
feat: Add full support for analytical windowing functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jimidle committed May 30, 2024
1 parent 027340f commit 0f10799
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,6 @@ CREATION_DISPOSITION : 'CREATION_DISPOSITION';
CREDENTIAL : 'CREDENTIAL';
CROSS : 'CROSS';
CRYPTOGRAPHIC : 'CRYPTOGRAPHIC';
CUME_DIST : 'CUME_DIST';
CURRENT : 'CURRENT';
CURRENT_TIME : 'CURRENT_TIME';
CURSOR : 'CURSOR';
Expand Down Expand Up @@ -338,7 +337,6 @@ FILE_SNAPSHOT : 'FILE_SNAPSHOT';
FILLFACTOR : 'FILLFACTOR';
FILTER : 'FILTER';
FIRST : 'FIRST';
FIRST_VALUE : 'FIRST_VALUE';
FMTONLY : 'FMTONLY';
FOLLOWING : 'FOLLOWING';
FOR : 'FOR';
Expand Down Expand Up @@ -440,11 +438,8 @@ KEY_PATH : 'KEY_PATH';
KEY_SOURCE : 'KEY_SOURCE';
KEY_STORE_PROVIDER_NAME : 'KEY_STORE_PROVIDER_NAME';
KILL : 'KILL';
LAG : 'LAG';
LANGUAGE : 'LANGUAGE';
LAST : 'LAST';
LAST_VALUE : 'LAST_VALUE';
LEAD : 'LEAD';
LEFT : 'LEFT';
LEVEL : 'LEVEL';
LIBRARY : 'LIBRARY';
Expand Down Expand Up @@ -618,9 +613,6 @@ PATH : 'PATH';
PAUSE : 'PAUSE';
PDW_SHOWSPACEUSED : 'PDW_SHOWSPACEUSED';
PERCENT : 'PERCENT';
PERCENTILE_CONT : 'PERCENTILE_CONT';
PERCENTILE_DISC : 'PERCENTILE_DISC';
PERCENT_RANK : 'PERCENT_RANK';
PERMISSION_SET : 'PERMISSION_SET';
PERSISTED : 'PERSISTED';
PERSIST_SAMPLE_PERCENT : 'PERSIST_SAMPLE_PERCENT';
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3904,6 +3904,7 @@ expression
| caseExpression #exprCase
| expression timeZone #exprTz
| expression overClause #exprOver
| expression withinGroup #exprWithinGroup
| id #exprId
| DOLLAR_ACTION #exprDollar
| <assoc=right> expression DOT expression #exprDot
Expand Down Expand Up @@ -4282,14 +4283,11 @@ derivedTable
;

functionCall
: analyticWindowedFunction
| builtInFunctions
: builtInFunctions
| standardFunction
| freetextFunction
| partitionFunction
| hierarchyidStaticMethod
// TODO: This is broken and highly ambiguous - will need to be reworked so the expression allows the primitives
// | scalarFunctionName LPAREN expressionList? RPAREN
;

// Standard functions are built in but take standard syntax, or are
Expand Down Expand Up @@ -4485,13 +4483,8 @@ expressionList
;

// https://docs.microsoft.com/en-us/sql/t-sql/functions/analytic-functions-transact-sql
analyticWindowedFunction
: (FIRST_VALUE | LAST_VALUE) LPAREN expression RPAREN
| (LAG | LEAD) LPAREN expression (COMMA expression (COMMA expression)?)? RPAREN
| (CUME_DIST | PERCENT_RANK) LPAREN RPAREN OVER LPAREN (PARTITION BY expressionList)? orderByClause RPAREN
| (PERCENTILE_CONT | PERCENTILE_DISC) LPAREN expression RPAREN WITHIN GROUP LPAREN orderByClause RPAREN OVER LPAREN (
PARTITION BY expressionList
)? RPAREN
withinGroup
: WITHIN GROUP LPAREN orderByClause RPAREN
;

// https://msdn.microsoft.com/en-us/library/ms189461.aspx
Expand Down Expand Up @@ -4638,6 +4631,7 @@ nullNotnull
: NOT? NULL_
;

// TODO: Get rid of this after checking
scalarFunctionName
: funcProcNameServerDatabaseSchema
| RIGHT
Expand Down Expand Up @@ -4827,7 +4821,6 @@ keyword
| CREATION_DISPOSITION
| CREDENTIAL
| CRYPTOGRAPHIC
| CUME_DIST
| CURSOR_CLOSE_ON_COMMIT
| CURSOR_DEFAULT
| DATA
Expand Down Expand Up @@ -4892,7 +4885,6 @@ keyword
| FILESTREAM
| FILTER
| FIRST
| FIRST_VALUE
| FMTONLY
| FOLLOWING
| FORCE
Expand Down Expand Up @@ -4949,10 +4941,7 @@ keyword
| KEY_SOURCE
| KEYS
| KEYSET
| LAG
| LAST
| LAST_VALUE
| LEAD
| LEVEL
| LIST
| LISTENER
Expand Down Expand Up @@ -5052,9 +5041,6 @@ keyword
| PATH
| PAUSE
| PDW_SHOWSPACEUSED
| PERCENT_RANK
| PERCENTILE_CONT
| PERCENTILE_DISC
| PERSIST_SAMPLE_PERCENT
| PHYSICAL_ONLY
| POISON_MESSAGE_HANDLING
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ object FunctionBuilder {
case "COT" => Some(FixedArity(1))
case "COUNT" => Some(FixedArity(1))
case "COUNT_BIG" => Some(FixedArity(1))
case "CUME_DIST" => Some(FixedArity(0))
case "CURRENT_DATE" => Some(FixedArity(0))
case "CURRENT_REQUEST_ID" => Some(FixedArity(0))
case "CURRENT_TIMESTAMP" => Some(FixedArity(0))
Expand Down Expand Up @@ -109,6 +110,7 @@ object FunctionBuilder {
case "FILEGROUPPROPERTY" => Some(FixedArity(2))
case "FILEPROPERTY" => Some(FixedArity(2))
case "FILEPROPERTYEX" => Some(FixedArity(2))
case "FIRST_VALUE" => Some(FixedArity(1))
case "FLOOR" => Some(FixedArity(1))
case "FORMAT" => Some(VariableArity(2, 3))
case "FORMATMESSAGE" => Some(VariableArity(2, Int.MaxValue))
Expand Down Expand Up @@ -148,6 +150,9 @@ object FunctionBuilder {
case "JSON_PATH_EXISTS" => Some(FixedArity(2))
case "JSON_QUERY" => Some(FixedArity(2))
case "JSON_VALUE" => Some(FixedArity(2))
case "LAG" => Some(VariableArity(1, 3))
case "LAST_VALUE" => Some(FixedArity(1))
case "LEAD" => Some(VariableArity(1, 3))
case "LEAST" => Some(VariableArity(1, Int.MaxValue))
case "LEFT" => Some(FixedArity(2))
case "LEN" => Some(FixedArity(1))
Expand Down Expand Up @@ -178,6 +183,9 @@ object FunctionBuilder {
case "PARSE" => Some(VariableArity(2, 3, convertible = false)) // Not in DBSQL
case "PARSENAME" => Some(FixedArity(2))
case "PATINDEX" => Some(FixedArity(2))
case "PERCENT_RANK" => Some(FixedArity(0))
case "PERCENTILE_CONT" => Some(FixedArity(1))
case "PERCENTILE_DISC" => Some(FixedArity(1))
case "PERMISSIONS" => Some(VariableArity(0, 2, convertible = false)) // not in DBSQL
case "PI" => Some(FixedArity(0))
case "POWER" => Some(FixedArity(2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,5 @@ case class ScalarSubquery(relation: Relation) extends Expression {}
case class Timezone(expression: Expression, timeZone: Expression) extends Expression {}

case class Money(value: Literal) extends Expression {}

case class WithinGroup(expression: Expression, order: Seq[SortOrder]) extends Expression {}
Original file line number Diff line number Diff line change
Expand Up @@ -293,22 +293,24 @@ class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with Pa
val partitionByExpressions =
Option(ctx.overClause().expression()).map(_.asScala.toList.map(_.accept(this))).getOrElse(List.empty)
val orderByExpressions = Option(ctx.overClause().orderByClause())
.map(_.orderByExpression().asScala.toList.map { orderByExpr =>
val expression = orderByExpr.expression().accept(this)
val sortOrder =
if (Option(orderByExpr.DESC()).isDefined) ir.DescendingSortDirection
else ir.AscendingSortDirection
ir.SortOrder(expression, sortOrder, ir.SortNullsUnspecified)
})
.map(buildOrderBy)
.getOrElse(List.empty)

val rowRange = Option(ctx.overClause().rowOrRangeClause())
.map(buildWindowFrame)
.getOrElse(noWindowFrame)

ir.Window(windowFunction, partitionByExpressions, orderByExpressions, rowRange)
}

private def buildOrderBy(ctx: OrderByClauseContext): Seq[ir.SortOrder] =
ctx.orderByExpression().asScala.map { orderByExpr =>
val expression = orderByExpr.expression().accept(this)
val sortOrder =
if (Option(orderByExpr.DESC()).isDefined) ir.DescendingSortDirection
else ir.AscendingSortDirection
ir.SortOrder(expression, sortOrder, ir.SortNullsUnspecified)
}

private def noWindowFrame: ir.WindowFrame =
ir.WindowFrame(
ir.UndefinedFrame,
Expand Down Expand Up @@ -363,6 +365,12 @@ class TSqlExpressionBuilder extends TSqlParserBaseVisitor[ir.Expression] with Pa
}
}

override def visitExprWithinGroup(ctx: ExprWithinGroupContext): ir.Expression = {
val expression = ctx.expression().accept(this)
val orderByExpressions = buildOrderBy(ctx.withinGroup().orderByClause())
ir.WithinGroup(expression, orderByExpressions)
}

override def visitExprDistinct(ctx: ExprDistinctContext): ir.Expression = {
// Support for functions such as COUNT(DISTINCT column), which is an expression not a relation
ir.Distinct(ctx.expression().accept(this))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.databricks.labs.remorph.parsers.common

import com.databricks.labs.remorph.parsers.intermediate.UnresolvedFunction
import com.databricks.labs.remorph.parsers.{FixedArity, FunctionArity, FunctionBuilder, FunctionType, StandardFunction, UnknownFunction, VariableArity, XmlFunction}
import com.databricks.labs.remorph.parsers._
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers
import org.scalatest.prop.TableDrivenPropertyChecks
Expand All @@ -14,7 +15,6 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
val functions = Table(
("functionName", "expectedArity"), // Header

("MODIFY", Some(FixedArity(1, XmlFunction))),
("ABS", Some(FixedArity(1))),
("ACOS", Some(FixedArity(1))),
("APP_NAME", Some(FixedArity(0))),
Expand Down Expand Up @@ -46,6 +46,7 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
("COT", Some(FixedArity(1))),
("COUNT", Some(FixedArity(1))),
("COUNT_BIG", Some(FixedArity(1))),
("CUME_DIST", Some(FixedArity(0))),
("CURRENT_DATE", Some(FixedArity(0))),
("CURRENT_REQUEST_ID", Some(FixedArity(0))),
("CURRENT_TIMESTAMP", Some(FixedArity(0))),
Expand Down Expand Up @@ -93,6 +94,7 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
("FILEGROUPPROPERTY", Some(FixedArity(2))),
("FILEPROPERTY", Some(FixedArity(2))),
("FILEPROPERTYEX", Some(FixedArity(2))),
("FIRST_VALUE", Some(FixedArity(1))),
("FLOOR", Some(FixedArity(1))),
("FORMAT", Some(VariableArity(2, 3))),
("FORMATMESSAGE", Some(VariableArity(2, Int.MaxValue))),
Expand Down Expand Up @@ -132,6 +134,9 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
("JSON_PATH_EXISTS", Some(FixedArity(2))),
("JSON_QUERY", Some(FixedArity(2))),
("JSON_VALUE", Some(FixedArity(2))),
("LAG", Some(VariableArity(1, 3))),
("LAST_VALUE", Some(FixedArity(1))),
("LEAD", Some(VariableArity(1, 3))),
("LEAST", Some(VariableArity(1, Int.MaxValue))),
("LEFT", Some(FixedArity(2))),
("LEN", Some(FixedArity(1))),
Expand All @@ -143,6 +148,7 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
("MAX", Some(FixedArity(1))),
("MIN", Some(FixedArity(1))),
("MIN_ACTIVE_ROWVERSION", Some(FixedArity(0))),
("MODIFY", Some(FixedArity(1, XmlFunction))),
("MONTH", Some(FixedArity(1))),
("NCHAR", Some(FixedArity(1))),
("NEWID", Some(FixedArity(0))),
Expand All @@ -161,6 +167,9 @@ class FunctionBuilderSpec extends AnyFlatSpec with Matchers with TableDrivenProp
("PARSE", Some(VariableArity(2, 3, convertible = false))),
("PARSENAME", Some(FixedArity(2))),
("PATINDEX", Some(FixedArity(2))),
("PERCENT_RANK", Some(FixedArity(0))),
("PERCENTILE_CONT", Some(FixedArity(1))),
("PERCENTILE_DISC", Some(FixedArity(1))),
("PERMISSIONS", Some(VariableArity(0, 2, convertible = false))),
("PI", Some(FixedArity(0))),
("POWER", Some(FixedArity(2))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -209,8 +209,48 @@ class TSqlFunctionSpec extends AnyWordSpec with TSqlParserTestCommon with Matche
example("COUNT(DISTINCT salary)", _.expression(), ir.CallFunction("COUNT", Seq(ir.Distinct(ir.Column("salary")))))
}

// TODO: Analytic functions are next
"translate analytic windowing functions in all forms" ignore {
"translate analytic windowing functions in all forms" in {

example(
query = "FIRST_VALUE(Salary) OVER (PARTITION BY DepartmentID ORDER BY Salary DESC)",
_.expression(),
ir.Window(
ir.CallFunction("FIRST_VALUE", Seq(ir.Column("Salary"))),
Seq(ir.Column("DepartmentID")),
Seq(ir.SortOrder(ir.Column("Salary"), ir.DescendingSortDirection, ir.SortNullsUnspecified)),
ir.WindowFrame(
ir.UndefinedFrame,
ir.FrameBoundary(current_row = false, unbounded = false, ir.Noop),
ir.FrameBoundary(current_row = false, unbounded = false, ir.Noop))))

example(
query = """
LAST_VALUE(salary) OVER (PARTITION BY department_id ORDER BY employee_id DESC)
""",
_.expression(),
ir.Window(
ir.CallFunction("LAST_VALUE", Seq(ir.Column("salary"))),
Seq(ir.Column("department_id")),
Seq(ir.SortOrder(ir.Column("employee_id"), ir.DescendingSortDirection, ir.SortNullsUnspecified)),
ir.WindowFrame(
ir.UndefinedFrame,
ir.FrameBoundary(current_row = false, unbounded = false, ir.Noop),
ir.FrameBoundary(current_row = false, unbounded = false, ir.Noop))))

example(
query = "PERCENTILE_CONT(0.5) WITHIN GROUP (ORDER BY Salary) OVER (PARTITION BY DepartmentID)",
_.expression(),
ir.Window(
ir.WithinGroup(
ir.CallFunction("PERCENTILE_CONT", Seq(ir.Literal(float = Some(0.5f)))),
Seq(ir.SortOrder(ir.Column("Salary"), ir.AscendingSortDirection, ir.SortNullsUnspecified))),
Seq(ir.Column("DepartmentID")),
List(),
ir.WindowFrame(
ir.UndefinedFrame,
ir.FrameBoundary(current_row = false, unbounded = false, ir.Noop),
ir.FrameBoundary(current_row = false, unbounded = false, ir.Noop))))

example(
query = """
LEAD(salary, 1) OVER (PARTITION BY department_id ORDER BY employee_id DESC)
Expand Down

0 comments on commit 0f10799

Please sign in to comment.