Review unwrap and panic from the aggregate directory of `datafusion-physical-expr` (#3443)

* approx agg

* expand downcast_value

* Update datafusion/common/src/lib.rs

* Update datafusion/common/src/lib.rs

* agg almost done

* agg done

Co-authored-by: Andy Grove <[email protected]>
chloeandmargaret and andygrove authored Sep 12, 2022
1 parent a8c1579 commit c5c1dae
Showing 8 changed files with 61 additions and 44 deletions.
16 changes: 16 additions & 0 deletions datafusion/common/src/error.rs
@@ -363,3 +363,19 @@ macro_rules! internal_err {
Err(DataFusionError::Internal(format!($($arg)*)))
};
}

/// Unwrap an `Option` if possible. Otherwise return a `DataFusionError::Internal`.
/// In normal usage of DataFusion the unwrap should always succeed.
///
/// Example: `let values = unwrap_or_internal_err!(values)`
#[macro_export]
macro_rules! unwrap_or_internal_err {
($Value: ident) => {
$Value.ok_or_else(|| {
DataFusionError::Internal(format!(
"{} should not be None",
stringify!($Value)
))
})?
};
}
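For reference, a minimal sketch of how the new macro is used at a call site, assuming the `datafusion_common` crate from this commit is available as a dependency; the function and variable names below are illustrative, not part of the commit:

```rust
use datafusion_common::{unwrap_or_internal_err, DataFusionError, Result};

// Illustrative helper: by the time this runs, `maybe_mean` should always be Some.
fn finalize_mean(maybe_mean: Option<f64>) -> Result<f64> {
    // Expands to `maybe_mean.ok_or_else(|| DataFusionError::Internal(...))?`,
    // so an unexpected None becomes a DataFusionError::Internal instead of a panic.
    let mean = unwrap_or_internal_err!(maybe_mean);
    Ok(mean)
}

fn main() {
    assert_eq!(finalize_mean(Some(1.5)).unwrap(), 1.5);
    assert!(finalize_mean(None).is_err());
}
```

Note that `DataFusionError` must be in scope at the call site, because the macro body refers to it unqualified.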
4 changes: 2 additions & 2 deletions datafusion/common/src/lib.rs
@@ -28,8 +28,8 @@ pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
pub use error::{field_not_found, DataFusionError, Result, SchemaError};
pub use scalar::{ScalarType, ScalarValue};

/// Downcast an Arrow Array to a concrete type, return an `Err` if the cast is
/// not possible.
/// Downcast an Arrow Array to a concrete type, return a `DataFusionError::Internal` if the cast is
/// not possible. In normal usage of DataFusion the downcast should always succeed.
///
/// Example: `let array = downcast_value!(values, Int32Array)`
#[macro_export]
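The body of `downcast_value!` is collapsed in this hunk. Below is a hedged sketch of the general shape such a macro takes (not the verbatim DataFusion definition); it also explains why several files later in this diff add `use std::any::{type_name, Any}`: the expansion refers to `type_name` at the call site when it builds the error message.

```rust
// Sketch only -- the real macro lives in datafusion/common/src/lib.rs and may
// differ in detail. It turns a failed downcast into DataFusionError::Internal
// instead of panicking via `.unwrap()`.
#[macro_export]
macro_rules! downcast_value {
    ($VALUE:expr, $ARRAYTYPE:ident) => {{
        $VALUE.as_any().downcast_ref::<$ARRAYTYPE>().ok_or_else(|| {
            DataFusionError::Internal(format!(
                "could not cast value to {}",
                type_name::<$ARRAYTYPE>()
            ))
        })?
    }};
}
```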
36 changes: 19 additions & 17 deletions datafusion/physical-expr/src/aggregate/covariance.rs
@@ -29,7 +29,7 @@ use arrow::{
datatypes::DataType,
datatypes::Field,
};
use datafusion_common::{downcast_value, ScalarValue};
use datafusion_common::{downcast_value, unwrap_or_internal_err, ScalarValue};
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};

@@ -266,29 +266,31 @@ impl Accumulator for CovarianceAccumulator {
"The two columns are not aligned".to_string(),
));
}
} else {
let value1 = unwrap_or_internal_err!(value1);
let value2 = unwrap_or_internal_err!(value2);
let new_count = self.count + 1;
let delta1 = value1 - self.mean1;
let new_mean1 = delta1 / new_count as f64 + self.mean1;
let delta2 = value2 - self.mean2;
let new_mean2 = delta2 / new_count as f64 + self.mean2;
let new_c = delta1 * (value2 - new_mean2) + self.algo_const;

self.count += 1;
self.mean1 = new_mean1;
self.mean2 = new_mean2;
self.algo_const = new_c;
}

let new_count = self.count + 1;
let delta1 = value1.unwrap() - self.mean1;
let new_mean1 = delta1 / new_count as f64 + self.mean1;
let delta2 = value2.unwrap() - self.mean2;
let new_mean2 = delta2 / new_count as f64 + self.mean2;
let new_c = delta1 * (value2.unwrap() - new_mean2) + self.algo_const;

self.count += 1;
self.mean1 = new_mean1;
self.mean2 = new_mean2;
self.algo_const = new_c;
}

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
let means1 = states[1].as_any().downcast_ref::<Float64Array>().unwrap();
let means2 = states[2].as_any().downcast_ref::<Float64Array>().unwrap();
let cs = states[3].as_any().downcast_ref::<Float64Array>().unwrap();
let counts = downcast_value!(states[0], UInt64Array);
let means1 = downcast_value!(states[1], Float64Array);
let means2 = downcast_value!(states[2], Float64Array);
let cs = downcast_value!(states[3], Float64Array);

for i in 0..counts.len() {
let c = counts.value(i);
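The rewritten `else` branch above is the usual single-pass (Welford-style) covariance recurrence, now run only once both values are known to be present. A standalone sketch of the same update step, using plain `f64` inputs and illustrative names:

```rust
/// One step of the running-covariance recurrence mirrored from the code above:
/// update both running means first, then accumulate delta1 * (value2 - new_mean2).
fn covariance_step(
    count: u64,
    mean1: f64,
    mean2: f64,
    algo_const: f64,
    value1: f64,
    value2: f64,
) -> (u64, f64, f64, f64) {
    let new_count = count + 1;
    let delta1 = value1 - mean1;
    let new_mean1 = delta1 / new_count as f64 + mean1;
    let delta2 = value2 - mean2;
    let new_mean2 = delta2 / new_count as f64 + mean2;
    let new_c = delta1 * (value2 - new_mean2) + algo_const;
    (new_count, new_mean1, new_mean2, new_c)
}
```

At evaluation time a covariance accumulator typically divides the accumulated constant by n or n - 1 for the population or sample statistic, respectively.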
12 changes: 6 additions & 6 deletions datafusion/physical-expr/src/aggregate/min_max.rs
@@ -17,7 +17,7 @@

//! Defines physical expressions that can be evaluated at runtime during query execution
use std::any::Any;
use std::any::{type_name, Any};
use std::convert::TryFrom;
use std::sync::Arc;

@@ -35,7 +35,7 @@ use arrow::{
datatypes::Field,
};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_common::{downcast_value, DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};

use crate::aggregate::row_accumulator::RowAccumulator;
@@ -145,7 +145,7 @@ impl AggregateExpr for Max {
// Statically-typed version of min/max(array) -> ScalarValue for string types.
macro_rules! typed_min_max_batch_string {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
let array = downcast_value!($VALUES, $ARRAYTYPE);
let value = compute::$OP(array);
let value = value.and_then(|e| Some(e.to_string()));
ScalarValue::$SCALAR(value)
@@ -155,13 +155,13 @@ macro_rules! typed_min_max_batch_string {
// Statically-typed version of min/max(array) -> ScalarValue for non-string types.
macro_rules! typed_min_max_batch {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
let array = downcast_value!($VALUES, $ARRAYTYPE);
let value = compute::$OP(array);
ScalarValue::$SCALAR(value)
}};

($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident, $OP:ident, $TZ:expr) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
let array = downcast_value!($VALUES, $ARRAYTYPE);
let value = compute::$OP(array);
ScalarValue::$SCALAR(value, $TZ.clone())
}};
@@ -176,7 +176,7 @@ macro_rules! typed_min_max_batch_decimal128 {
if null_count == $VALUES.len() {
ScalarValue::Decimal128(None, *$PRECISION, *$SCALE)
} else {
let array = $VALUES.as_any().downcast_ref::<Decimal128Array>().unwrap();
let array = downcast_value!($VALUES, Decimal128Array);
if null_count == 0 {
// there is no null value
let mut result = array.value(0);
7 changes: 5 additions & 2 deletions datafusion/physical-expr/src/aggregate/mod.rs
@@ -18,7 +18,7 @@
use crate::aggregate::row_accumulator::RowAccumulator;
use crate::PhysicalExpr;
use arrow::datatypes::Field;
use datafusion_common::Result;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::Accumulator;
use std::any::Any;
use std::fmt::Debug;
@@ -97,6 +97,9 @@ pub trait AggregateExpr: Send + Sync + Debug {
&self,
_start_index: usize,
) -> Result<Box<dyn RowAccumulator>> {
unreachable!()
Err(DataFusionError::NotImplemented(format!(
"RowAccumulator hasn't been implemented for {:?} yet",
self
)))
}
}
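Replacing `unreachable!()` with `DataFusionError::NotImplemented` turns a hard panic into an error the caller can branch on. A hedged sketch of that pattern, using a stand-in function because the trait method's name is elided in the hunk above:

```rust
use datafusion_common::{DataFusionError, Result};

// Stand-in for the default trait method above; in real code this would be the
// row-accumulator constructor on `AggregateExpr` (name elided in the hunk).
fn row_accumulator_stub() -> Result<()> {
    Err(DataFusionError::NotImplemented(
        "RowAccumulator hasn't been implemented for this aggregate yet".to_string(),
    ))
}

fn main() -> Result<()> {
    match row_accumulator_stub() {
        Ok(()) => println!("row-based fast path available"),
        // Recoverable: the caller can fall back to the Box<dyn Accumulator> path
        // instead of aborting, which the old `unreachable!()` would have forced.
        Err(DataFusionError::NotImplemented(msg)) => println!("falling back: {msg}"),
        Err(e) => return Err(e),
    }
    Ok(())
}
```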
8 changes: 4 additions & 4 deletions datafusion/physical-expr/src/aggregate/sum.rs
@@ -17,7 +17,7 @@

//! Defines physical expressions that can be evaluated at runtime during query execution
use std::any::Any;
use std::any::{type_name, Any};
use std::convert::TryFrom;
use std::sync::Arc;

@@ -31,7 +31,7 @@ use arrow::{
},
datatypes::Field,
};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_common::{downcast_value, DataFusionError, Result, ScalarValue};
use datafusion_expr::{Accumulator, AggregateState};

use crate::aggregate::row_accumulator::RowAccumulator;
@@ -144,7 +144,7 @@ impl SumAccumulator {
// returns the new value after sum with the new values, taking nullability into account
macro_rules! typed_sum_delta_batch {
($VALUES:expr, $ARRAYTYPE:ident, $SCALAR:ident) => {{
let array = $VALUES.as_any().downcast_ref::<$ARRAYTYPE>().unwrap();
let array = downcast_value!($VALUES, $ARRAYTYPE);
let delta = compute::sum(array);
ScalarValue::$SCALAR(delta)
}};
@@ -153,7 +153,7 @@ macro_rules! typed_sum_delta_batch {
// TODO implement this in arrow-rs with simd
// https://github.com/apache/arrow-rs/issues/1010
fn sum_decimal_batch(values: &ArrayRef, precision: u8, scale: u8) -> Result<ScalarValue> {
let array = values.as_any().downcast_ref::<Decimal128Array>().unwrap();
let array = downcast_value!(values, Decimal128Array);

if array.null_count() == array.len() {
return Ok(ScalarValue::Decimal128(None, precision, scale));
6 changes: 3 additions & 3 deletions datafusion/physical-expr/src/aggregate/sum_distinct.rs
@@ -171,9 +171,9 @@ impl Accumulator for DistinctSumAccumulator {

fn evaluate(&self) -> Result<ScalarValue> {
let mut sum_value = ScalarValue::try_from(&self.data_type)?;
self.hash_values.iter().for_each(|distinct_value| {
sum_value = sum::sum(&sum_value, distinct_value).unwrap()
});
for distinct_value in self.hash_values.iter() {
sum_value = sum::sum(&sum_value, distinct_value)?;
}
Ok(sum_value)
}
}
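The switch from `for_each` to a plain `for` loop is what makes the `?` possible: inside a closure, `?` would try to return from the closure (which returns `()`), not from `evaluate`, so the old code had no choice but to `unwrap()`. A minimal, DataFusion-free sketch of the same pattern:

```rust
// `?` propagates the error from `sum_all` itself because the loop body runs in
// the enclosing function; the same line inside `.for_each(|v| ...)` would not
// compile, since that closure cannot return a Result from `sum_all`.
fn sum_all(values: &[Option<i32>]) -> Result<i32, String> {
    let mut total = 0;
    for v in values.iter().copied() {
        total += v.ok_or_else(|| "unexpected None".to_string())?;
    }
    Ok(total)
}

fn main() {
    assert_eq!(sum_all(&[Some(1), Some(2), Some(3)]), Ok(6));
    assert!(sum_all(&[Some(1), None]).is_err());
}
```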
16 changes: 6 additions & 10 deletions datafusion/physical-expr/src/aggregate/variance.rs
@@ -17,7 +17,7 @@

//! Defines physical expressions that can be evaluated at runtime during query execution
use std::any::Any;
use std::any::{type_name, Any};
use std::sync::Arc;

use crate::aggregate::stats::StatsType;
@@ -30,6 +30,7 @@ use arrow::{
datatypes::DataType,
datatypes::Field,
};
use datafusion_common::downcast_value;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Accumulator, AggregateState};
@@ -220,12 +221,7 @@ impl Accumulator for VarianceAccumulator {

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = &cast(&values[0], &DataType::Float64)?;
let arr = values
.as_any()
.downcast_ref::<Float64Array>()
.unwrap()
.iter()
.flatten();
let arr = downcast_value!(values, Float64Array).iter().flatten();

for value in arr {
let new_count = self.count + 1;
@@ -243,9 +239,9 @@ }
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let counts = states[0].as_any().downcast_ref::<UInt64Array>().unwrap();
let means = states[1].as_any().downcast_ref::<Float64Array>().unwrap();
let m2s = states[2].as_any().downcast_ref::<Float64Array>().unwrap();
let counts = downcast_value!(states[0], UInt64Array);
let means = downcast_value!(states[1], Float64Array);
let m2s = downcast_value!(states[2], Float64Array);

for i in 0..counts.len() {
let c = counts.value(i);
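The body of the merge loop is collapsed above. For orientation only, the textbook pairwise combination of two partial `(count, mean, m2)` states (Chan et al.) looks like the sketch below; this is the standard formula, not a claim about the exact lines elided from `merge_batch`:

```rust
/// Standard combination of two partial variance states; shown for reference,
/// not copied from the collapsed DataFusion code.
fn merge_states(a: (u64, f64, f64), b: (u64, f64, f64)) -> (u64, f64, f64) {
    let (n_a, mean_a, m2_a) = a;
    let (n_b, mean_b, m2_b) = b;
    if n_a == 0 {
        return b;
    }
    if n_b == 0 {
        return a;
    }
    let n = n_a + n_b;
    let delta = mean_b - mean_a;
    // Weighted mean, then m2 picks up the between-partition spread term.
    let mean = mean_a + delta * n_b as f64 / n as f64;
    let m2 = m2_a + m2_b + delta * delta * (n_a as f64 * n_b as f64) / n as f64;
    (n, mean, m2)
}
```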
