Skip to content

Commit

Permalink
Change Accumulator::evaluate and Accumulator::state to take `&mut…
Browse files Browse the repository at this point in the history
… self`
  • Loading branch information
alamb committed Jan 21, 2024
1 parent b7e13a0 commit 70ffb2f
Show file tree
Hide file tree
Showing 36 changed files with 117 additions and 111 deletions.
4 changes: 2 additions & 2 deletions datafusion-examples/examples/advanced_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
Expand All @@ -113,7 +113,7 @@ impl Accumulator for GeometricMean {

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion-examples/examples/simple_udaf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl Accumulator for GeometricMean {
// This function serializes our state to `ScalarValue`, which DataFusion uses
// to pass this state between execution stages.
// Note that this can be arbitrary data.
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.prod),
ScalarValue::from(self.n),
Expand All @@ -81,7 +81,7 @@ impl Accumulator for GeometricMean {

// DataFusion expects this function to return the final value of this aggregator.
// in this case, this is the formula of the geometric mean
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let value = self.prod.powf(1.0 / self.n as f64);
Ok(ScalarValue::from(value))
}
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -194,11 +194,11 @@ pub(crate) fn get_col_stats(
) -> Vec<ColumnStatistics> {
(0..schema.fields().len())
.map(|i| {
let max_value = match &max_values[i] {
let max_value = match max_values.get_mut(i).unwrap() {
Some(max_value) => max_value.evaluate().ok(),
None => None,
};
let min_value = match &min_values[i] {
let min_value = match min_values.get_mut(i).unwrap() {
Some(min_value) => min_value.evaluate().ok(),
None => None,
};
Expand Down
8 changes: 4 additions & 4 deletions datafusion/core/tests/user_defined/user_defined_aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ impl TimeSum {
}

impl Accumulator for TimeSum {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -457,7 +457,7 @@ impl Accumulator for TimeSum {
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
println!("Evaluating to {}", self.sum);
Ok(ScalarValue::TimestampNanosecond(Some(self.sum), None))
}
Expand Down Expand Up @@ -582,14 +582,14 @@ impl FirstSelector {
}

impl Accumulator for FirstSelector {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let state = self.to_state().into_iter().collect::<Vec<_>>();

Ok(state)
}

/// produce the output structure
fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(self.to_scalar())
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ async fn udaf_as_window_func() -> Result<()> {
struct MyAccumulator;

impl Accumulator for MyAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
unimplemented!()
}

Expand All @@ -260,7 +260,7 @@ async fn udaf_as_window_func() -> Result<()> {
unimplemented!()
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
unimplemented!()
}

Expand Down
12 changes: 9 additions & 3 deletions datafusion/expr/src/accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,17 @@ pub trait Accumulator: Send + Sync + Debug {
/// running sum.
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;

/// Returns the final aggregate value.
/// Returns the final aggregate value and resets internal state.
///
/// For example, the `SUM` accumulator maintains a running sum,
/// and `evaluate` will produce that running sum as its output.
fn evaluate(&self) -> Result<ScalarValue>;
///
/// After this call, the accumulator's internal state should be
/// equivalent to when it was first created.
///
/// This function gets a `mut` accumulator to allow for the accumulator to
/// use an arrow compatible internal state when possible.
fn evaluate(&mut self) -> Result<ScalarValue>;

/// Returns the allocated size required for this accumulator, in
/// bytes, including `Self`.
Expand Down Expand Up @@ -129,7 +135,7 @@ pub trait Accumulator: Send + Sync + Debug {
/// Note that [`ScalarValue::List`] can be used to pass multiple
/// values if the number of intermediate values is not known at
/// planning time (e.g. for `MEDIAN`)
fn state(&self) -> Result<Vec<ScalarValue>>;
fn state(&mut self) -> Result<Vec<ScalarValue>>;

/// Updates the accumulator's state from an `Array` containing one
/// or more intermediate values.
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/approx_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -244,12 +244,12 @@ macro_rules! default_accumulator_impl {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let value = ScalarValue::from(&self.hll);
Ok(vec![value])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::UInt64(Some(self.hll.count() as u64)))
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -377,7 +377,7 @@ impl ApproxPercentileAccumulator {
}

impl Accumulator for ApproxPercentileAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(self.digest.to_scalar_state().into_iter().collect())
}

Expand All @@ -389,7 +389,7 @@ impl Accumulator for ApproxPercentileAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
if self.digest.count() == 0.0 {
return exec_err!("aggregate function needs at least one non-null element");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl ApproxPercentileWithWeightAccumulator {
}

impl Accumulator for ApproxPercentileWithWeightAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
self.approx_percentile_cont_accumulator.state()
}

Expand All @@ -155,7 +155,7 @@ impl Accumulator for ApproxPercentileWithWeightAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
self.approx_percentile_cont_accumulator.evaluate()
}

Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,11 @@ impl Accumulator for ArrayAggAccumulator {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
// Transform Vec<ListArr> to ListArr

let element_arrays: Vec<&dyn Array> =
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ impl DistinctArrayAggAccumulator {
}

impl Accumulator for DistinctArrayAggAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand Down Expand Up @@ -163,7 +163,7 @@ impl Accumulator for DistinctArrayAggAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values: Vec<ScalarValue> = self.values.iter().cloned().collect();
let arr = ScalarValue::new_list(&values, &self.datatype);
Ok(ScalarValue::List(arr))
Expand Down
4 changes: 2 additions & 2 deletions datafusion/physical-expr/src/aggregate/array_agg_ordered.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,13 +279,13 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
Ok(())
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let mut result = vec![self.evaluate()?];
result.push(self.evaluate_orderings()?);
Ok(result)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let values = self.values.clone();
let array = if self.reverse {
ScalarValue::new_list_from_iter(values.into_iter().rev(), &self.datatypes[0])
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/aggregate/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ pub struct AvgAccumulator {
}

impl Accumulator for AvgAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::Float64(self.sum),
Expand Down Expand Up @@ -277,7 +277,7 @@ impl Accumulator for AvgAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(
self.sum.map(|f| f / self.count as f64),
))
Expand Down Expand Up @@ -315,7 +315,7 @@ impl<T: DecimalType + ArrowNumericType> Debug for DecimalAvgAccumulator<T> {
}

impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T> {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.count),
ScalarValue::new_primitive::<T>(
Expand Down Expand Up @@ -357,7 +357,7 @@ impl<T: DecimalType + ArrowNumericType> Accumulator for DecimalAvgAccumulator<T>
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let v = self
.sum
.map(|v| {
Expand Down
16 changes: 8 additions & 8 deletions datafusion/physical-expr/src/aggregate/bit_and_or_xor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -190,11 +190,11 @@ where
self.update_batch(states)
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -339,7 +339,7 @@ impl<T: ArrowNumericType> Accumulator for BitOrAccumulator<T>
where
T::Native: std::ops::BitOr<Output = T::Native>,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -355,7 +355,7 @@ where
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -500,7 +500,7 @@ impl<T: ArrowNumericType> Accumulator for BitXorAccumulator<T>
where
T::Native: std::ops::BitXor<Output = T::Native>,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![self.evaluate()?])
}

Expand All @@ -516,7 +516,7 @@ where
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
ScalarValue::new_primitive::<T>(self.value, &T::DATA_TYPE)
}

Expand Down Expand Up @@ -634,7 +634,7 @@ impl<T: ArrowNumericType> Accumulator for DistinctBitXorAccumulator<T>
where
T::Native: std::ops::BitXor<Output = T::Native> + std::hash::Hash + Eq,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
// 1. Stores aggregate state in `ScalarValue::List`
// 2. Constructs `ScalarValue::List` state from distinct numeric stored in hash set
let state_out = {
Expand Down Expand Up @@ -679,7 +679,7 @@ where
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let mut acc = T::Native::usize_as(0);
for distinct_value in self.values.iter() {
acc = acc ^ *distinct_value;
Expand Down
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/aggregate/bool_and_or.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,11 +191,11 @@ impl Accumulator for BoolAndAccumulator {
self.update_batch(states)
}

fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::Boolean(self.acc)])
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Boolean(self.acc))
}

Expand Down Expand Up @@ -309,7 +309,7 @@ struct BoolOrAccumulator {
}

impl Accumulator for BoolOrAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::Boolean(self.acc)])
}

Expand All @@ -328,7 +328,7 @@ impl Accumulator for BoolOrAccumulator {
self.update_batch(states)
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Boolean(self.acc))
}

Expand Down
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/aggregate/correlation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ impl CorrelationAccumulator {
}

impl Accumulator for CorrelationAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![
ScalarValue::from(self.covar.get_count()),
ScalarValue::from(self.covar.get_mean1()),
Expand Down Expand Up @@ -215,7 +215,7 @@ impl Accumulator for CorrelationAccumulator {
Ok(())
}

fn evaluate(&self) -> Result<ScalarValue> {
fn evaluate(&mut self) -> Result<ScalarValue> {
let covar = self.covar.evaluate()?;
let stddev1 = self.stddev1.evaluate()?;
let stddev2 = self.stddev2.evaluate()?;
Expand Down Expand Up @@ -519,7 +519,7 @@ mod tests {
.collect::<Result<Vec<_>>>()?;
accum1.update_batch(&values1)?;
accum2.update_batch(&values2)?;
let state2 = get_accum_scalar_values_as_arrays(accum2.as_ref())?;
let state2 = get_accum_scalar_values_as_arrays(accum2.as_mut())?;
accum1.merge_batch(&state2)?;
accum1.evaluate()
}
Expand Down
Loading

0 comments on commit 70ffb2f

Please sign in to comment.