Skip to content

Commit

Permalink
supports specialized count(0) (#619)
Browse files Browse the repository at this point in the history
Co-authored-by: zhangli20 <[email protected]>
  • Loading branch information
richox and zhangli20 authored Oct 22, 2024
1 parent f222016 commit 549a59b
Show file tree
Hide file tree
Showing 15 changed files with 259 additions and 21 deletions.
3 changes: 2 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/agg_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,11 +371,12 @@ impl AggContext {
pub fn partial_update_input_all(
&self,
acc: &mut RefAccumStateRow,
num_rows: usize,
input_arrays: &[Vec<ArrayRef>],
) -> Result<()> {
if self.need_partial_update {
for (idx, agg) in &self.need_partial_update_aggs {
agg.partial_update_all(acc, &input_arrays[*idx])?;
agg.partial_update_all(acc, num_rows, &input_arrays[*idx])?;
}
}
Ok(())
Expand Down
11 changes: 8 additions & 3 deletions native-engine/datafusion-ext-plans/src/agg/avg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,9 +146,14 @@ impl Agg for AggAvg {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
self.agg_sum.partial_update_all(acc, values)?;
self.agg_count.partial_update_all(acc, values)?;
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
self.agg_sum.partial_update_all(acc, num_rows, values)?;
self.agg_count.partial_update_all(acc, num_rows, values)?;
Ok(())
}

Expand Down
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/bloom_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,12 @@ impl Agg for AggBloomFilter {
df_unimplemented_err!("AggBloomFilter::partial_update is not implemented")
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let bloom_filter = match acc.dyn_value_mut(self.accum_state_val_addr) {
Some(v) => downcast_any!(v, mut SparkBloomFilter)?,
v @ None => {
Expand Down
14 changes: 11 additions & 3 deletions native-engine/datafusion-ext-plans/src/agg/brickhouse/collect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,22 @@ impl Agg for AggCollect {
) -> Result<()> {
let list = values[0].as_list::<i32>();
if list.is_valid(row_idx) {
let num_rows = 0; // unused
self.innert_collect_list
.partial_update_all(acc, &[list.value(row_idx)])?;
.partial_update_all(acc, num_rows, &[list.value(row_idx)])?;
}
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
self.innert_collect_list.partial_update_all(acc, values)
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let num_rows = 0; // unused
self.innert_collect_list
.partial_update_all(acc, num_rows, values)
}

fn partial_merge(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,14 +106,22 @@ impl Agg for AggCombineUnique {
) -> Result<()> {
let list = values[0].as_list::<i32>();
if list.is_valid(row_idx) {
let num_rows = 0; // unused
self.innert_collect_set
.partial_update_all(acc, &[list.value(row_idx)])?;
.partial_update_all(acc, num_rows, &[list.value(row_idx)])?;
}
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
self.innert_collect_set.partial_update_all(acc, values)
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let num_rows = 0; // unused
self.innert_collect_set
.partial_update_all(acc, num_rows, values)
}

fn partial_merge(
Expand Down
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/collect_list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ impl Agg for AggCollectList {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let dyn_list = match acc.dyn_value_mut(self.accum_state_val_addr) {
Some(dyn_list) => dyn_list,
w => {
Expand Down
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/collect_set.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,12 @@ impl Agg for AggCollectSet {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let dyn_set = match acc.dyn_value_mut(self.accum_state_val_addr) {
Some(dyn_set) => dyn_set,
w => {
Expand Down
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,12 @@ impl Agg for AggCount {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let addr = self.accum_state_val_addr;
let num_valids = values[0].len() - values[0].null_count();
acc.update_fixed_value::<i64>(addr, |v| v + num_valids as i64);
Expand Down
164 changes: 164 additions & 0 deletions native-engine/datafusion-ext-plans/src/agg/count0.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
// Copyright 2022 The Blaze Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::{
any::Any,
fmt::{Debug, Formatter},
sync::{atomic::AtomicUsize, Arc},
};

use arrow::{array::*, datatypes::*};
use datafusion::{
common::{Result, ScalarValue},
physical_expr::PhysicalExpr,
};

use crate::agg::{
acc::{AccumInitialValue, AccumStateRow, AccumStateValAddr, RefAccumStateRow},
Agg, WithAggBufAddrs, WithMemTracking,
};

pub struct AggCount0 {
accum_state_val_addr: AccumStateValAddr,
mem_used_tracker: AtomicUsize,
}

impl WithAggBufAddrs for AggCount0 {
fn set_accum_state_val_addrs(&mut self, accum_state_val_addrs: &[AccumStateValAddr]) {
self.accum_state_val_addr = accum_state_val_addrs[0];
}
}

impl WithMemTracking for AggCount0 {
fn mem_used_tracker(&self) -> &AtomicUsize {
&self.mem_used_tracker
}
}

impl AggCount0 {
pub fn try_new(data_type: DataType) -> Result<Self> {
assert_eq!(data_type, DataType::Int64);
Ok(Self {
accum_state_val_addr: AccumStateValAddr::default(),
mem_used_tracker: AtomicUsize::new(0),
})
}
}

impl Debug for AggCount0 {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "Count(0)")
}
}

impl Agg for AggCount0 {
fn as_any(&self) -> &dyn Any {
self
}

fn exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![]
}

fn with_new_exprs(&self, _exprs: Vec<Arc<dyn PhysicalExpr>>) -> Result<Arc<dyn Agg>> {
Ok(Arc::new(Self::try_new(DataType::Int64)?))
}

fn data_type(&self) -> &DataType {
&DataType::Int64
}

fn nullable(&self) -> bool {
false
}

fn accums_initial(&self) -> &[AccumInitialValue] {
&[AccumInitialValue::Scalar(ScalarValue::Int64(Some(0)))]
}

fn increase_acc_mem_used(&self, _acc: &mut RefAccumStateRow) {
// do nothing
}

fn partial_update(
&self,
acc: &mut RefAccumStateRow,
_values: &[ArrayRef],
_row_idx: usize,
) -> Result<()> {
let addr = self.accum_state_val_addr;
acc.update_fixed_value::<i64>(addr, |v| v + 1);
Ok(())
}

fn partial_batch_update(
&self,
accs: &mut [RefAccumStateRow],
_values: &[ArrayRef],
) -> Result<()> {
let addr = self.accum_state_val_addr;
for acc in accs.iter_mut() {
acc.update_fixed_value::<i64>(addr, |v| v + 1);
}
Ok(())
}

fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
num_rows: usize,
_values: &[ArrayRef],
) -> Result<()> {
let addr = self.accum_state_val_addr;
acc.update_fixed_value::<i64>(addr, |v| v + num_rows as i64);
Ok(())
}

fn partial_merge(
&self,
acc1: &mut RefAccumStateRow,
acc2: &mut RefAccumStateRow,
) -> Result<()> {
let addr = self.accum_state_val_addr;
let num_valids2 = acc2.fixed_value::<i64>(addr);
acc1.update_fixed_value::<i64>(addr, |v| v + num_valids2);
Ok(())
}

fn partial_batch_merge(
&self,
accs: &mut [RefAccumStateRow],
merging_accs: &mut [RefAccumStateRow],
) -> Result<()> {
let addr = self.accum_state_val_addr;
for (acc, merging_acc) in accs.iter_mut().zip(merging_accs) {
let merging_num_valids = merging_acc.fixed_value::<i64>(addr);
acc.update_fixed_value::<i64>(addr, |v| v + merging_num_valids);
}
Ok(())
}
fn final_merge(&self, acc: &mut RefAccumStateRow) -> Result<ScalarValue> {
let addr = self.accum_state_val_addr;
Ok(ScalarValue::from(acc.fixed_value::<i64>(addr)))
}

fn final_batch_merge(&self, accs: &mut [RefAccumStateRow]) -> Result<ArrayRef> {
let addr = self.accum_state_val_addr;
Ok(Arc::new(
accs.iter()
.map(|acc| acc.fixed_value::<i64>(addr))
.collect::<Int64Array>(),
))
}
}
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/first.rs
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,12 @@ impl Agg for AggFirst {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
if !self.is_touched(acc) {
let value = &values[0];
if !value.is_empty() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,12 @@ impl Agg for AggFirstIgnoresNull {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
let partial_updater = self.partial_updater;
let value = &values[0];

Expand Down
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/maxmin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,12 @@ impl<P: AggMaxMinParams> Agg for AggMaxMin<P> {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
macro_rules! handle_fixed {
($ty:ident, $maxfun:expr) => {{
type TArray = paste! {[<$ty Array>]};
Expand Down
15 changes: 13 additions & 2 deletions native-engine/datafusion-ext-plans/src/agg/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ pub mod brickhouse;
pub mod collect_list;
pub mod collect_set;
pub mod count;
pub mod count0;
pub mod first;
pub mod first_ignores_null;
pub mod maxmin;
Expand Down Expand Up @@ -168,7 +169,13 @@ pub trait Agg: WithAggBufAddrs + WithMemTracking + Send + Sync + Debug {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()>;
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
num_rows: usize,
values: &[ArrayRef],
) -> Result<()>;

fn partial_merge(
&self,
acc: &mut RefAccumStateRow,
Expand Down Expand Up @@ -202,7 +209,11 @@ pub fn create_agg(
Ok(match agg_function {
AggFunction::Count => {
let return_type = DataType::Int64;
Arc::new(count::AggCount::try_new(children[0].clone(), return_type)?)
if children[0].nullable(input_schema)? {
Arc::new(count::AggCount::try_new(children[0].clone(), return_type)?)
} else {
Arc::new(count0::AggCount0::try_new(return_type)?)
}
}
AggFunction::Sum => {
let arg_type = children[0].data_type(input_schema)?;
Expand Down
7 changes: 6 additions & 1 deletion native-engine/datafusion-ext-plans/src/agg/sum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,12 @@ impl Agg for AggSum {
Ok(())
}

fn partial_update_all(&self, acc: &mut RefAccumStateRow, values: &[ArrayRef]) -> Result<()> {
fn partial_update_all(
&self,
acc: &mut RefAccumStateRow,
_num_rows: usize,
values: &[ArrayRef],
) -> Result<()> {
macro_rules! handle {
($ty:ident) => {{
type TArray = paste! {[<$ty Array>]};
Expand Down
Loading

0 comments on commit 549a59b

Please sign in to comment.