Support STRING order-by columns for RANGE window functions (#8182)
Closes #7883.
Depends on rapidsai/cudf#13143, rapidsai/cudf#13199.

This commit adds support for `STRING` order-by columns for RANGE window functions.

Before this commit, only numeric, date, and timestamp types were supported as order-by columns in window specifications. However, it is possible to specify window frames such as the following:
```sql
SELECT COUNT(1) OVER( PARTITION BY gid ORDER BY str_col )
```
The implicit range here is `UNBOUNDED PRECEDING AND CURRENT ROW`, although explicit bounds may also be specified.
Note that numeric range offsets cannot be specified here, because `STRING` does not support intervals.

This change should now allow the plugin to support `UNBOUNDED PRECEDING`, `UNBOUNDED FOLLOWING`, and `CURRENT ROW` as range window bounds, when the order-by column is `STRING`.
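For example, reusing the `gid`/`str_col` names from the query above, an explicitly bounded frame that falls within this support looks like:
```sql
SELECT COUNT(1) OVER( PARTITION BY gid
                      ORDER BY str_col
                      RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW )
```
A numeric offset such as `RANGE BETWEEN 3 PRECEDING AND CURRENT ROW` remains unsupported for `STRING` order-by columns, for the interval reason noted above.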

Signed-off-by: MithunR <[email protected]>
mythrocks authored Apr 27, 2023
1 parent 9b9ad17 commit 2bd9b17
Showing 3 changed files with 61 additions and 3 deletions.
47 changes: 47 additions & 0 deletions integration_tests/src/main/python/window_function_test.py
@@ -383,6 +383,53 @@ def test_window_aggs_for_rows(data_gen, batch_size):
conf = conf)


@ignore_order(local=True)
@pytest.mark.parametrize('batch_size', ['1000', '1g'], ids=idfn)
@pytest.mark.parametrize('data_gen', [
    [('grp', RepeatSeqGen(int_gen, length=20)),  # Grouping column.
     ('ord', LongRangeGen(nullable=True)),       # Order-by column (after cast to STRING).
     ('agg', IntegerGen())]                      # Aggregation column.
], ids=idfn)
def test_range_windows_with_string_order_by_column(data_gen, batch_size):
    """
    Tests that RANGE window functions can be used with STRING order-by columns.
    """
    assert_gpu_and_cpu_are_equal_sql(
        lambda spark: gen_df(spark, data_gen, length=2048),
        'window_agg_table',
        'SELECT '
        ' ROW_NUMBER() OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC) as row_num_asc, '
        ' RANK() OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC) as rank_desc, '
        ' DENSE_RANK() OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC) as dense_rank_asc, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC) as count_1_asc_default, '
        ' COUNT(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC) as count_desc_default, '
        ' SUM(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC '
        '    RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as sum_asc_UNB_to_CURRENT, '
        ' MIN(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC '
        '    RANGE BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as min_desc_UNB_to_CURRENT, '
        ' MAX(agg) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC '
        '    RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as max_asc_CURRENT_to_UNB, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC '
        '    RANGE BETWEEN CURRENT ROW AND UNBOUNDED FOLLOWING) as count_1_desc_CURRENT_to_UNB, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) ASC '
        '    RANGE BETWEEN CURRENT ROW AND CURRENT ROW) as count_1_asc_CURRENT_to_CURRENT, '
        ' COUNT(1) OVER '
        '   (PARTITION BY grp ORDER BY CAST(ord AS STRING) DESC '
        '    RANGE BETWEEN CURRENT ROW AND CURRENT ROW) as count_1_desc_CURRENT_to_CURRENT '
        ' FROM window_agg_table ',
        conf={'spark.rapids.sql.batchSizeBytes': batch_size})


# This is for aggregations that work with a running window optimization. They don't need to be batched
# specially, but it only works if all of the aggregations can support this.
# the order returned should be consistent because the data ends up in a single task (no partitioning)
@@ -684,13 +684,21 @@ object GroupedAggregations {
      if (preceding.isEmpty) {
        windowOptionBuilder.unboundedPreceding()
      } else {
-       windowOptionBuilder.preceding(preceding.get)
+       if (orderType == DType.STRING) { // Bounded STRING bounds can only mean "CURRENT ROW".
+         windowOptionBuilder.currentRowPreceding()
+       } else {
+         windowOptionBuilder.preceding(preceding.get)
+       }
      }

      if (following.isEmpty) {
        windowOptionBuilder.unboundedFollowing()
      } else {
-       windowOptionBuilder.following(following.get)
+       if (orderType == DType.STRING) { // Bounded STRING bounds can only mean "CURRENT ROW".
+         windowOptionBuilder.currentRowFollowing()
+       } else {
+         windowOptionBuilder.following(following.get)
+       }
      }

      if (orderExpr.isAscending) {
@@ -763,6 +771,9 @@ object GroupedAggregations {
            Scalar.fromDecimal(x.getScale, valueLong.get)
          case x if x.getTypeId == DType.DTypeEnum.DECIMAL128 =>
            Scalar.fromDecimal(x.getScale, bound.value.left.get.underlying())
+         case x if x.getTypeId == DType.DTypeEnum.STRING =>
+           // Not UNBOUNDED. The only other supported boundary for String is CURRENT ROW, i.e. 0.
+           Scalar.fromString("")
          case _ => throw new RuntimeException(s"Not supported order by type, Found $orderByType")
        }
        Some(s)
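To summarize the two hunks above, here is a condensed, illustrative sketch of the new bound mapping. This is not the plugin's actual code: `mapPrecedingBound` is a hypothetical helper, and it assumes cudf's Java `WindowOptions.Builder` along with the `currentRowPreceding()`/`currentRowFollowing()` methods added by rapidsai/cudf#13199:

```scala
import ai.rapids.cudf.{DType, Scalar, WindowOptions}

// Hypothetical helper (not plugin code): applies a range frame's PRECEDING
// bound to a cudf WindowOptions builder. STRING has no interval arithmetic,
// so a bounded frame edge over a STRING order-by column can only mean
// CURRENT ROW.
def mapPrecedingBound(builder: WindowOptions.Builder,
                      preceding: Option[Scalar],
                      orderType: DType): Unit = preceding match {
  case None => builder.unboundedPreceding()
  case Some(_) if orderType == DType.STRING => builder.currentRowPreceding()
  case Some(scalar) => builder.preceding(scalar)
}
```

The FOLLOWING bound is handled symmetrically via `currentRowFollowing()`/`following(...)`. The `Scalar.fromString("")` built in the second hunk appears to serve as a placeholder bound for the STRING case, with `currentRowPreceding()`/`currentRowFollowing()` carrying the actual semantics.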
@@ -107,7 +107,7 @@ abstract class GpuWindowExpressionMetaBase(
    val orderByTypeSupported = orderSpec.forall { so =>
      so.dataType match {
        case ByteType | ShortType | IntegerType | LongType |
-            DateType | TimestampType | DecimalType() => true
+            DateType | TimestampType | StringType | DecimalType() => true
        case _ => false
      }
    }