Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add array_distance function #12211

Merged
merged 10 commits into from
Aug 29, 2024
218 changes: 218 additions & 0 deletions datafusion/functions-aggregate/src/distance.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! [`Distance`]: Euclidean distance aggregations.

use std::any::Any;
use std::fmt::Debug;

use arrow::compute::kernels::cast;
use arrow::compute::{and, filter, is_not_null};
use arrow::{
array::{ArrayRef, Float64Array},
datatypes::{DataType, Field},
};

use datafusion_common::{
downcast_value, plan_err, unwrap_or_internal_err, DataFusionError, Result,
ScalarValue,
};
use datafusion_expr::{
function::{AccumulatorArgs, StateFieldsArgs},
type_coercion::aggregates::NUMERICS,
utils::format_state_name,
Accumulator, AggregateUDFImpl, Signature, Volatility,
};

make_udaf_expr_and_func!(
Distance,
dis,
y x,
"Distance between two numeric values.",
austin362667 marked this conversation as resolved.
Show resolved Hide resolved
dis_udaf
);

#[derive(Debug)]
pub struct Distance {
signature: Signature,
}

impl Default for Distance {
fn default() -> Self {
Self::new()
}
}

impl Distance {
pub fn new() -> Self {
Self {
signature: Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable),
austin362667 marked this conversation as resolved.
Show resolved Hide resolved
}
}
}

impl AggregateUDFImpl for Distance {
fn as_any(&self) -> &dyn Any {
self
}

fn name(&self) -> &str {
"distance"
}

fn signature(&self) -> &Signature {
&self.signature
}

fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
if !arg_types[0].is_numeric() {
return plan_err!("Distance requires numeric input types");
}
austin362667 marked this conversation as resolved.
Show resolved Hide resolved

Ok(DataType::Float64)
}
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
Ok(Box::new(DistanceAccumulator::try_new()?))
}

fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<Field>> {
let name = args.name;
Ok(vec![Field::new(
format_state_name(name, "sum_of_squares"),
DataType::Float64,
true,
)])
}
}

/// An accumulator to compute distance of two numeric columns
#[derive(Debug)]
pub struct DistanceAccumulator {
sum_of_squares: f64,
}

impl DistanceAccumulator {
/// Creates a new `DistanceAccumulator`
pub fn try_new() -> Result<Self> {
Ok(Self {
sum_of_squares: 0_f64,
})
}
}

impl Accumulator for DistanceAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
Ok(vec![ScalarValue::from(self.sum_of_squares)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
let values1 = filter(&values[0], &mask)?;
let values2 = filter(&values[1], &mask)?;

vec![values1, values2]
} else {
values.to_vec()
};

let values1 = &cast(&values[0], &DataType::Float64)?;
let values2 = &cast(&values[1], &DataType::Float64)?;

let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();
austin362667 marked this conversation as resolved.
Show resolved Hide resolved

for i in 0..values1.len() {
let value1 = if values1.is_valid(i) {
arr1.next()
} else {
None
};
let value2 = if values2.is_valid(i) {
arr2.next()
} else {
None
};
if value1.is_none() || value2.is_none() {
continue;
}

let value1 = unwrap_or_internal_err!(value1);
let value2 = unwrap_or_internal_err!(value2);

self.sum_of_squares += (value1 - value2).powi(2);
}
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let values = if values[0].null_count() != 0 || values[1].null_count() != 0 {
let mask = and(&is_not_null(&values[0])?, &is_not_null(&values[1])?)?;
let values1 = filter(&values[0], &mask)?;
let values2 = filter(&values[1], &mask)?;

vec![values1, values2]
} else {
values.to_vec()
};

let values1 = &cast(&values[0], &DataType::Float64)?;
let values2 = &cast(&values[1], &DataType::Float64)?;
let mut arr1 = downcast_value!(values1, Float64Array).iter().flatten();
let mut arr2 = downcast_value!(values2, Float64Array).iter().flatten();

for i in 0..values1.len() {
let value1 = if values1.is_valid(i) {
arr1.next()
} else {
None
};
let value2 = if values2.is_valid(i) {
arr2.next()
} else {
None
};

if value1.is_none() || value2.is_none() {
continue;
}

let value1 = unwrap_or_internal_err!(value1);
let value2 = unwrap_or_internal_err!(value2);

self.sum_of_squares -= (value1 - value2).powi(2);
}

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
let list_sum_of_squares = downcast_value!(states[0], Float64Array);
for i in 0..list_sum_of_squares.len() {
self.sum_of_squares += list_sum_of_squares.value(i);
}
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Float64(Some(self.sum_of_squares.sqrt())))
}

fn size(&self) -> usize {
std::mem::size_of_val(self)
}
}
3 changes: 3 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ pub mod array_agg;
pub mod correlation;
pub mod count;
pub mod covariance;
pub mod distance;
pub mod first_last;
pub mod hyperloglog;
pub mod median;
Expand Down Expand Up @@ -107,6 +108,7 @@ pub mod expr_fn {
pub use super::count::count_distinct;
pub use super::covariance::covar_pop;
pub use super::covariance::covar_samp;
pub use super::distance::dis;
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::grouping::grouping;
Expand Down Expand Up @@ -138,6 +140,7 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_samp_udaf(),
covariance::covar_pop_udaf(),
correlation::corr_udaf(),
distance::dis_udaf(),
sum::sum_udaf(),
min_max::max_udaf(),
min_max::min_udaf(),
Expand Down
44 changes: 44 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,50 @@ from data
----
1

# csv_query_distance
query R
SELECT distance(c2, c12) FROM aggregate_test_100
----
27.565541154252

# single_row_query_distance
query R
select distance(sq.column1, sq.column2) from (values (1.1, 2.2)) as sq
----
1.1

# all_nulls_query_distance
query R
with data as (
select null::int as f, null::int as b
union all
select null::int as f, null::int as b
)
select distance(f, b)
from data
----
0
austin362667 marked this conversation as resolved.
Show resolved Hide resolved

# distance_query_with_nulls
query R
with data as (
select 1 as f, 4 as b
union all
select null as f, 99 as b
union all
select 2 as f, 5 as b
union all
select 98 as f, null as b
union all
select 3 as f, 6 as b
union all
select null as f, null as b
)
select distance(f, b)
from data
----
5.196152422707

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

> select distance(sq.column1, sq.column2) from (values (NULL, 2), (0,0)) as sq;
+---------------------------------+
| distance(sq.column1,sq.column2) |
+---------------------------------+
| 0.0                             |
+---------------------------------+

I prefer it to be NULL

Copy link
Contributor Author

@austin362667 austin362667 Aug 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure let's do it this way~

# csv_query_variance_1
query R
SELECT var_pop(c2) FROM aggregate_test_100
Expand Down