Skip to content

Commit

Permalink
Replace variadic with exact where appropriate (#885)
Browse files Browse the repository at this point in the history
* replace variadic with exact where appropriate

* create generate_numeric_signatures function

* lint, rand/rand_integer, atan2/var_pop

* style fix

* generate_signatures and test

* style fix
  • Loading branch information
sarahyurick authored Nov 2, 2022
1 parent 1d6b737 commit fd68c28
Showing 1 changed file with 132 additions and 27 deletions.
159 changes: 132 additions & 27 deletions dask_planner/src/sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,32 +148,38 @@ impl ContextProvider for DaskSQLContext {

match name {
"year" => {
let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
let sig = generate_numeric_signatures(1);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
"atan2" | "mod" => {
let sig = Signature::variadic(
vec![DataType::Float64, DataType::Float64],
Volatility::Immutable,
);
"mod" => {
let sig = generate_numeric_signatures(2);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
"cbrt" | "cot" | "degrees" | "radians" | "sign" | "truncate" => {
let sig = Signature::variadic(vec![DataType::Float64], Volatility::Immutable);
let sig = generate_numeric_signatures(1);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
"rand" => {
let sig = Signature::variadic(vec![DataType::Int64], Volatility::Volatile);
let sig = Signature::one_of(
vec![
TypeSignature::Exact(vec![]),
TypeSignature::Exact(vec![DataType::Int64]),
],
Volatility::Immutable,
);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
}
"rand_integer" => {
let sig = Signature::variadic(
vec![DataType::Int64, DataType::Int64],
Volatility::Volatile,
let sig = Signature::one_of(
vec![
TypeSignature::Exact(vec![DataType::Int64]),
TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
],
Volatility::Immutable,
);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
Expand Down Expand Up @@ -231,39 +237,28 @@ impl ContextProvider for DaskSQLContext {

match name {
"every" => {
let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
let sig = generate_numeric_signatures(1);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean)));
return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
}
"bit_and" | "bit_or" => {
let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
let sig = generate_numeric_signatures(1);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
}
"single_value" => {
let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
let sig = generate_numeric_signatures(1);
let rtf: ReturnTypeFunction =
Arc::new(|input_types| Ok(Arc::new(input_types[0].clone())));
return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
}
"regr_count" => {
let sig = Signature::variadic(
vec![DataType::Float64, DataType::Float64],
Volatility::Immutable,
);
let sig = generate_numeric_signatures(2);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
}
"regr_syy" | "regr_sxx" => {
let sig = Signature::variadic(
vec![DataType::Float64, DataType::Float64],
Volatility::Immutable,
);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
}
"var_pop" => {
let sig = Signature::variadic(vec![DataType::Float64], Volatility::Immutable);
let sig = generate_numeric_signatures(2);
let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
}
Expand Down Expand Up @@ -576,3 +571,113 @@ impl PlanVisitor for OptimizablePlanVisitor {
Ok(true)
}
}

/// Builds a [`Signature`] that accepts exactly `n` arguments, where every argument
/// may independently be any of the numeric Arrow types listed in the body.
///
/// The result is a `one_of` signature holding one `TypeSignature::Exact` entry per
/// element of the Cartesian product of the numeric type list with itself `n` times,
/// so the number of alternatives grows as `11^n`. The volatility is always
/// `Volatility::Immutable`.
///
/// A non-positive `n` produces a signature with no alternatives at all.
fn generate_numeric_signatures(n: i32) -> Signature {
    let datatypes = vec![
        DataType::Int8,
        DataType::Int16,
        DataType::Int32,
        DataType::Int64,
        DataType::UInt8,
        DataType::UInt16,
        DataType::UInt32,
        DataType::UInt64,
        DataType::Float16,
        DataType::Float32,
        DataType::Float64,
    ];

    // One copy of the numeric type list per argument position,
    // e.g. [datatypes, datatypes] when n == 2. An empty range (n <= 0)
    // yields no positions.
    let cartesian_setup: Vec<Vec<DataType>> = (0..n).map(|_| datatypes.clone()).collect();

    // The product and the TypeSignature wrapping are shared with the general helper.
    generate_signatures(cartesian_setup)
}

/// Builds a [`Signature`] from the Cartesian product of per-position type lists.
///
/// `cartesian_setup[i]` lists the data types allowed for argument `i`. Every
/// combination that picks one type per position becomes a `TypeSignature::Exact`
/// alternative of the returned `one_of` signature, which is always
/// `Volatility::Immutable`.
///
/// Combinations are ordered with the first position varying slowest and the last
/// position varying fastest. An empty `cartesian_setup`, or any position with an
/// empty type list, yields a signature with no alternatives.
fn generate_signatures(cartesian_setup: Vec<Vec<DataType>>) -> Signature {
    let mut per_position = cartesian_setup.into_iter();

    // Seed the product with one single-element combination per type of the
    // first position (or nothing when there are no positions).
    let mut combinations: Vec<Vec<DataType>> = per_position
        .next()
        .map(|first| first.into_iter().map(|datatype| vec![datatype]).collect())
        .unwrap_or_default();

    // Extend every existing combination with each type of the next position.
    for candidates in per_position {
        combinations = combinations
            .iter()
            .flat_map(|prefix| {
                candidates.iter().map(move |datatype| {
                    let mut extended = prefix.clone();
                    extended.push(datatype.clone());
                    extended
                })
            })
            .collect();
    }

    let one_of_vector = combinations
        .into_iter()
        .map(TypeSignature::Exact)
        .collect();

    Signature::one_of(one_of_vector, Volatility::Immutable)
}

#[cfg(test)]
mod test {
    use arrow::datatypes::DataType;
    use datafusion_expr::{Signature, TypeSignature, Volatility};

    use crate::sql::generate_signatures;

    /// Two positions with two candidate types each must expand to the four
    /// exact signatures of their Cartesian product, with the first position
    /// varying slowest.
    #[test]
    fn test_generate_signatures() {
        let first_position = vec![DataType::Int64, DataType::Float64];
        let second_position = vec![DataType::Utf8, DataType::Int64];

        let expected_rows = vec![
            vec![DataType::Int64, DataType::Utf8],
            vec![DataType::Int64, DataType::Int64],
            vec![DataType::Float64, DataType::Utf8],
            vec![DataType::Float64, DataType::Int64],
        ];
        let expected = Signature::one_of(
            expected_rows
                .into_iter()
                .map(TypeSignature::Exact)
                .collect(),
            Volatility::Immutable,
        );

        let sig = generate_signatures(vec![first_position, second_position]);

        assert_eq!(sig, expected);
    }
}

0 comments on commit fd68c28

Please sign in to comment.