Skip to content

Commit

Permalink
Address comments.
Browse files Browse the repository at this point in the history
  • Loading branch information
dongjoon-hyun committed Apr 27, 2016
1 parent 0373440 commit 152c6c2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1345,29 +1345,32 @@ object DecimalAggregates extends Rule[LogicalPlan] {

def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
case we @ WindowExpression(ae @ AggregateExpression(Sum(
e @ DecimalType.Expression(prec, scale)), _, _, _), _) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
prec + 10, scale)

case we @ WindowExpression(ae @ AggregateExpression(Average(
e @ DecimalType.Expression(prec, scale)), _, _, _), _) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr =
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))

case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)

case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))
case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
prec + 10, scale)

case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr =
we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))

case _ => we
}
case ae @ AggregateExpression(af, _, _, _) => af match {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)

case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
Cast(
Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
DecimalType(prec + 4, scale + 4))

case _ => ae
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,6 @@ case class WindowExec(
case e @ WindowExpression(function, spec) =>
val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame]
function match {
case MakeDecimal(AggregateExpression(f, _, _, _), prec, scale) =>
collect("AGGREGATE", frame, e, f)
case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f)
case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f)
case f: OffsetWindowFunction => collect("OFFSET", frame, e, f)
Expand Down

0 comments on commit 152c6c2

Please sign in to comment.