Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-49836][SQL][SS] Fix possibly broken query when window is provided to window/session_window fn #48309

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -87,85 +87,86 @@ object TimeWindowing extends Rule[LogicalPlan] {

val window = windowExpressions.head

// time window is provided as time column of window function, replace it with WindowTime
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reviewer: except this code comment, the only change is to remove return and replace it with a huge if-else statement.

if (StructType.acceptsType(window.timeColumn.dataType)) {
return p.transformExpressions {
p.transformExpressions {
case t: TimeWindow => t.copy(timeColumn = WindowTime(window.timeColumn))
}
}

val metadata = window.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}

val newMetadata = new MetadataBuilder()
.withMetadata(metadata)
.putBoolean(TimeWindow.marker, true)
.build()
} else {
val metadata = window.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}

def getWindow(i: Int, dataType: DataType): Expression = {
val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
val remainder = (timestamp - window.startTime) % window.slideDuration
val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
remainder + window.slideDuration)), Some(remainder))
val windowStart = lastStart - i * window.slideDuration
val windowEnd = windowStart + window.windowDuration
val newMetadata = new MetadataBuilder()
.withMetadata(metadata)
.putBoolean(TimeWindow.marker, true)
.build()

// We make sure value fields are nullable since the dataType of TimeWindow defines them
// as nullable.
CreateNamedStruct(
Literal(WINDOW_START) ::
PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
Literal(WINDOW_END) ::
PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
Nil)
}
def getWindow(i: Int, dataType: DataType): Expression = {
val timestamp = PreciseTimestampConversion(window.timeColumn, dataType, LongType)
val remainder = (timestamp - window.startTime) % window.slideDuration
val lastStart = timestamp - CaseWhen(Seq((LessThan(remainder, 0),
remainder + window.slideDuration)), Some(remainder))
val windowStart = lastStart - i * window.slideDuration
val windowEnd = windowStart + window.windowDuration

// We make sure value fields are nullable since the dataType of TimeWindow defines them
// as nullable.
CreateNamedStruct(
Literal(WINDOW_START) ::
PreciseTimestampConversion(windowStart, LongType, dataType).castNullable() ::
Literal(WINDOW_END) ::
PreciseTimestampConversion(windowEnd, LongType, dataType).castNullable() ::
Nil)
}

val windowAttr = AttributeReference(
WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()
val windowAttr = AttributeReference(
WINDOW_COL_NAME, window.dataType, metadata = newMetadata)()

if (window.windowDuration == window.slideDuration) {
val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))
if (window.windowDuration == window.slideDuration) {
val windowStruct = Alias(getWindow(0, window.timeColumn.dataType), WINDOW_COL_NAME)(
exprId = windowAttr.exprId, explicitMetadata = Some(newMetadata))

val replacedPlan = p transformExpressions {
case t: TimeWindow => windowAttr
}
val replacedPlan = p transformExpressions {
case t: TimeWindow => windowAttr
}

// For backwards compatibility we add a filter to filter out nulls
val filterExpr = IsNotNull(window.timeColumn)
// For backwards compatibility we add a filter to filter out nulls
val filterExpr = IsNotNull(window.timeColumn)

replacedPlan.withNewChildren(
Project(windowStruct +: child.output,
Filter(filterExpr, child)) :: Nil)
} else {
val overlappingWindows =
math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
val windows =
Seq.tabulate(overlappingWindows)(i =>
getWindow(i, window.timeColumn.dataType))

val projections = windows.map(_ +: child.output)

// When the condition windowDuration % slideDuration = 0 is fulfilled,
// the estimation of the number of windows becomes exact one,
// which means all produced windows are valid.
val filterExpr =
if (window.windowDuration % window.slideDuration == 0) {
IsNotNull(window.timeColumn)
replacedPlan.withNewChildren(
Project(windowStruct +: child.output,
Filter(filterExpr, child)) :: Nil)
} else {
window.timeColumn >= windowAttr.getField(WINDOW_START) &&
window.timeColumn < windowAttr.getField(WINDOW_END)
val overlappingWindows =
math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt
val windows =
Seq.tabulate(overlappingWindows)(i =>
getWindow(i, window.timeColumn.dataType))

val projections = windows.map(_ +: child.output)

// When the condition windowDuration % slideDuration = 0 is fulfilled,
// the estimation of the number of windows becomes exact one,
// which means all produced windows are valid.
val filterExpr =
if (window.windowDuration % window.slideDuration == 0) {
IsNotNull(window.timeColumn)
} else {
window.timeColumn >= windowAttr.getField(WINDOW_START) &&
window.timeColumn < windowAttr.getField(WINDOW_END)
}

val substitutedPlan = Filter(filterExpr,
Expand(projections, windowAttr +: child.output, child))

val renamedPlan = p transformExpressions {
case t: TimeWindow => windowAttr
}

renamedPlan.withNewChildren(substitutedPlan :: Nil)
}

val substitutedPlan = Filter(filterExpr,
Expand(projections, windowAttr +: child.output, child))

val renamedPlan = p transformExpressions {
case t: TimeWindow => windowAttr
}

renamedPlan.withNewChildren(substitutedPlan :: Nil)
}
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
Expand Down Expand Up @@ -210,74 +211,74 @@ object SessionWindowing extends Rule[LogicalPlan] {
val session = sessionExpressions.head

if (StructType.acceptsType(session.timeColumn.dataType)) {
return p transformExpressions {
p transformExpressions {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For reviewer: the only change is to remove return and replace it with a huge if-else statement.

case t: SessionWindow => t.copy(timeColumn = WindowTime(session.timeColumn))
}
}
} else {
val metadata = session.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}

val metadata = session.timeColumn match {
case a: Attribute => a.metadata
case _ => Metadata.empty
}
val newMetadata = new MetadataBuilder()
.withMetadata(metadata)
.putBoolean(SessionWindow.marker, true)
.build()

val newMetadata = new MetadataBuilder()
.withMetadata(metadata)
.putBoolean(SessionWindow.marker, true)
.build()

val sessionAttr = AttributeReference(
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()

val sessionStart =
PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
val gapDuration = session.gapDuration match {
case expr if expr.dataType == CalendarIntervalType =>
expr
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
Cast(expr, CalendarIntervalType)
case other =>
throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
}
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
session.timeColumn.dataType, LongType)

// We make sure value fields are nullable since the dataType of SessionWindow defines them
// as nullable.
val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
.castNullable() ::
Literal(SESSION_END) ::
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
.castNullable() ::
Nil)

val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))
val sessionAttr = AttributeReference(
SESSION_COL_NAME, session.dataType, metadata = newMetadata)()

val sessionStart =
PreciseTimestampConversion(session.timeColumn, session.timeColumn.dataType, LongType)
val gapDuration = session.gapDuration match {
case expr if expr.dataType == CalendarIntervalType =>
expr
case expr if Cast.canCast(expr.dataType, CalendarIntervalType) =>
Cast(expr, CalendarIntervalType)
case other =>
throw QueryCompilationErrors.sessionWindowGapDurationDataTypeError(other.dataType)
}
val sessionEnd = PreciseTimestampConversion(session.timeColumn + gapDuration,
session.timeColumn.dataType, LongType)

val replacedPlan = p transformExpressions {
case s: SessionWindow => sessionAttr
}
// We make sure value fields are nullable since the dataType of SessionWindow defines them
// as nullable.
val literalSessionStruct = CreateNamedStruct(
Literal(SESSION_START) ::
PreciseTimestampConversion(sessionStart, LongType, session.timeColumn.dataType)
.castNullable() ::
Literal(SESSION_END) ::
PreciseTimestampConversion(sessionEnd, LongType, session.timeColumn.dataType)
.castNullable() ::
Nil)

val filterByTimeRange = if (gapDuration.foldable) {
val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
interval == null || interval.months + interval.days + interval.microseconds <= 0
} else {
true
}
val sessionStruct = Alias(literalSessionStruct, SESSION_COL_NAME)(
exprId = sessionAttr.exprId, explicitMetadata = Some(newMetadata))

// As same as tumbling window, we add a filter to filter out nulls.
// And we also filter out events with negative or zero or invalid gap duration.
val filterExpr = if (filterByTimeRange) {
IsNotNull(session.timeColumn) &&
(sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
} else {
IsNotNull(session.timeColumn)
}
val replacedPlan = p transformExpressions {
case s: SessionWindow => sessionAttr
}

replacedPlan.withNewChildren(
Filter(filterExpr,
Project(sessionStruct +: child.output, child)) :: Nil)
val filterByTimeRange = if (gapDuration.foldable) {
val interval = gapDuration.eval().asInstanceOf[CalendarInterval]
interval == null || interval.months + interval.days + interval.microseconds <= 0
} else {
true
}

// As same as tumbling window, we add a filter to filter out nulls.
// And we also filter out events with negative or zero or invalid gap duration.
val filterExpr = if (filterByTimeRange) {
IsNotNull(session.timeColumn) &&
(sessionAttr.getField(SESSION_END) > sessionAttr.getField(SESSION_START))
} else {
IsNotNull(session.timeColumn)
}

replacedPlan.withNewChildren(
Filter(filterExpr,
Project(sessionStruct +: child.output, child)) :: Nil)
}
} else if (numWindowExpr > 1) {
throw QueryCompilationErrors.multiTimeWindowExpressionsNotSupportedError(p)
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -547,4 +547,60 @@ class DataFrameSessionWindowingSuite extends QueryTest with SharedSparkSession
}
}
}

// scalastyle:off line.size.limit
// DISCLAIM: This is a revision of below test, which was a part of report in the dev mailing
// list. CREDIT goes to @andrezjzera.
// https://github.com/andrzejzera/spark-bugs/blob/abae7a3839326a8eafc7516a51aca5e0c79282a6/spark-3.5/src/test/scala/SqlSyntaxTest.scala#L122-L165
// scalastyle:on
test("SPARK-49836 using window fn with window as parameter should preserve parent operator") {
withTempView("clicks") {
val df = Seq(
// small window: [00:00, 01:00), user1, 2
("2024-09-30 00:00:00", "user1"), ("2024-09-30 00:00:30", "user1"),
// small window: [01:00, 02:00), user2, 2
("2024-09-30 00:01:00", "user2"), ("2024-09-30 00:01:30", "user2"),
// small window: [03:00, 04:00), user1, 1
("2024-09-30 00:03:30", "user1"),
// small window: [11:00, 12:00), user1, 3
("2024-09-30 00:11:00", "user1"), ("2024-09-30 00:11:30", "user1"),
("2024-09-30 00:11:45", "user1")
).toDF("eventTime", "userId")

// session window: (01:00, 09:00), user1, 3 / (02:00, 07:00), user2, 2 /
// (12:00, 12:05), user1, 3

df.createOrReplaceTempView("clicks")

val aggregatedData = spark.sql(
"""
|SELECT
| userId,
| avg(cpu_large.numClicks) AS clicksPerSession
|FROM
|(
| SELECT
| session_window(small_window, '5 minutes') AS session,
| userId,
| sum(numClicks) AS numClicks
| FROM
| (
| SELECT
| window(eventTime, '1 minute') AS small_window,
| userId,
| count(*) AS numClicks
| FROM clicks
| GROUP BY window, userId
| ) cpu_small
| GROUP BY session_window, userId
|) cpu_large
|GROUP BY userId
|""".stripMargin)

checkAnswer(
aggregatedData,
Seq(Row("user1", 3), Row("user2", 2))
)
}
}
}
Loading