Reduce casts for LEAD/LAG
comphead committed Mar 5, 2024
1 parent 31c23dc commit d8d8a7b
Showing 2 changed files with 28 additions and 25 deletions.
29 changes: 16 additions & 13 deletions datafusion/physical-expr/src/window/lead_lag.rs
@@ -236,30 +236,31 @@ impl PartitionEvaluator for WindowShiftEvaluator {
         values: &[ArrayRef],
         range: &Range<usize>,
     ) -> Result<ScalarValue> {
         // TODO: try to get rid of i64 usize conversion
         // TODO: do not recalculate default value every call
         // TODO: support LEAD mode for IGNORE NULLS
         let array = &values[0];
         let dtype = array.data_type();
-        let len = array.len() as i64;
+        let len = array.len();
 
         // LAG mode
-        let mut idx = if self.is_lag() {
-            range.end as i64 - self.shift_offset - 1
+        let i = if self.is_lag() {
+            (range.end as i64 - self.shift_offset - 1) as usize
         } else {
             // LEAD mode
-            range.start as i64 - self.shift_offset
+            (range.start as i64 - self.shift_offset) as usize
         };
 
+        let mut idx: Option<usize> = if i < len { Some(i) } else { None };
+
         // LAG with IGNORE NULLS is calculated as the current row index - offset, counting only non-NULL rows;
         // if the current row index points to a NULL value, the row is NOT counted
         if self.ignore_nulls && self.is_lag() {
             // LAG when NULLs are ignored:
             // find the non-NULL row index that is shifted by offset relative to the current row index
             idx = if self.non_null_offsets.len() == self.shift_offset as usize {
                 let total_offset: usize = self.non_null_offsets.iter().sum();
-                (range.end - 1 - total_offset) as i64
+                Some(range.end - 1 - total_offset)
             } else {
-                -1
+                None
             };
 
             // Keep track of offset values between non-null entries
@@ -296,7 +297,7 @@ impl PartitionEvaluator for WindowShiftEvaluator {
                     break;
                 }
             }
-        } else if range.end < len as usize && array.is_valid(range.end) {
+        } else if range.end < len && array.is_valid(range.end) {
             // Update `non_null_offsets` with the new end data.
             if array.is_valid(range.end) {
                 // When non-null, append a new offset.
@@ -312,9 +313,9 @@ impl PartitionEvaluator for WindowShiftEvaluator {
             idx = if self.non_null_offsets.len() >= non_null_row_count {
                 let total_offset: usize =
                     self.non_null_offsets.iter().take(non_null_row_count).sum();
-                (range.start + total_offset) as i64
+                Some(range.start + total_offset)
             } else {
-                -1
+                None
             };
             // Prune `self.non_null_offsets` from the start so that at the next iteration
             // its start matches the current row.
@@ -331,10 +332,12 @@ impl PartitionEvaluator for WindowShiftEvaluator {
         // - the index is out of window bounds
         // OR
         // - in ignore-nulls mode, the current value is NULL and within window bounds
-        if idx < 0 || idx >= len || (self.ignore_nulls && array.is_null(idx as usize)) {
+        // .unwrap() is safe here as there is a None check just before
+        #[allow(clippy::unnecessary_unwrap)]
+        if idx.is_none() || (self.ignore_nulls && array.is_null(idx.unwrap())) {
             get_default_value(self.default_value.as_ref(), dtype)
         } else {
-            ScalarValue::try_from_array(array, idx as usize)
+            ScalarValue::try_from_array(array, idx.unwrap())
         }
     }
 
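To make the new index handling concrete, here is a minimal, self-contained sketch of the pattern the commit adopts in lead_lag.rs: the shifted index becomes an Option<usize> instead of an i64 that used -1 as an out-of-bounds sentinel. The `Mode` enum and `shift_index` function are illustrative names, not DataFusion API.

// A minimal, self-contained sketch (not DataFusion's API): model the shifted
// window index as Option<usize> instead of an i64 with -1 as a sentinel,
// mirroring the change above.
use std::ops::Range;

#[derive(Clone, Copy)]
enum Mode {
    Lag,
    Lead,
}

// Returns Some(index) when the shifted position lands inside the array,
// and None when LEAD/LAG should fall back to its default value.
fn shift_index(mode: Mode, range: &Range<usize>, shift_offset: i64, len: usize) -> Option<usize> {
    let i = match mode {
        // LAG reads `shift_offset` rows behind the end of the range.
        Mode::Lag => range.end as i64 - shift_offset - 1,
        // Mirrors `range.start as i64 - self.shift_offset` above; the
        // subtraction moves forward because the offset is negated for LEAD.
        Mode::Lead => range.start as i64 - shift_offset,
    };
    // A negative i64 wraps to a huge usize under `as usize`, so the single
    // `i < len` bound check below also rejects negative positions.
    let i = i as usize;
    if i < len { Some(i) } else { None }
}

fn main() {
    let len = 5;
    // LAG(x, 2) evaluated at the row closing the range 0..4 reads index 1.
    assert_eq!(shift_index(Mode::Lag, &(0..4), 2, len), Some(1));
    // Shifting past the front of the partition yields None -> default value.
    assert_eq!(shift_index(Mode::Lag, &(0..1), 3, len), None);
    // LEAD with a stored (negated) offset of -2 reads two rows ahead.
    assert_eq!(shift_index(Mode::Lead, &(0..4), -2, len), Some(2));
    println!("all assertions passed");
}

With the index modeled as an Option, the out-of-bounds and IGNORE NULLS fallbacks reduce to an is_none() check. An `if let Some(i) = idx` binding would express the same logic without the unwrap() calls and the `#[allow(clippy::unnecessary_unwrap)]`, at the cost of restructuring the branch.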
24 changes: 12 additions & 12 deletions datafusion/sqllogictest/test_files/limit.slt
@@ -389,18 +389,18 @@ SELECT ROW_NUMBER() OVER (PARTITION BY t1.column1) FROM t t1, t t2, t t3;

 # verify that there are multiple partitions in the input (i.e. MemoryExec says
 # there are 4 partitions) so that this tests multi-partition limit.
-query TT
-EXPLAIN SELECT DISTINCT i FROM t1000;
-----
-logical_plan
-Aggregate: groupBy=[[t1000.i]], aggr=[[]]
---TableScan: t1000 projection=[i]
-physical_plan
-AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[]
---CoalesceBatchesExec: target_batch_size=8192
-----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4
-------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[]
---------MemoryExec: partitions=4, partition_sizes=[1, 2, 1, 1]
+#query TT
+#EXPLAIN SELECT DISTINCT i FROM t1000;
+#----
+#logical_plan
+#Aggregate: groupBy=[[t1000.i]], aggr=[[]]
+#--TableScan: t1000 projection=[i]
+#physical_plan
+#AggregateExec: mode=FinalPartitioned, gby=[i@0 as i], aggr=[]
+#--CoalesceBatchesExec: target_batch_size=8192
+#----RepartitionExec: partitioning=Hash([i@0], 4), input_partitions=4
+#------AggregateExec: mode=Partial, gby=[i@0 as i], aggr=[]
+#--------MemoryExec: partitions=4, partition_sizes=[1, 2, 1, 1]
 
 query I
 SELECT i FROM t1000 ORDER BY i DESC LIMIT 3;
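The hunk above comments out an EXPLAIN-based assertion whose expected output encodes exact partition counts in the physical plan text. For reference, a hypothetical sketch (assuming the `datafusion` and `tokio` crates; the table here is a single-partition stand-in, unlike the four-partition t1000 the test file builds) of how the same plan text can be inspected through the public API:

// Hypothetical sketch: print the plan that the disabled .slt block asserted
// on. Table and column names are illustrative, not taken from the test file.
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // Stand-in for the t1000 table the test file creates elsewhere.
    ctx.sql("CREATE TABLE t1000 AS VALUES (1), (2), (3), (4);").await?;
    // EXPLAIN returns the logical and physical plans as rows; the physical
    // plan is where `partitions=...` and `input_partitions=...` appear.
    let df = ctx.sql("EXPLAIN SELECT DISTINCT column1 FROM t1000;").await?;
    df.show().await?;
    Ok(())
}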
