diff --git a/dask_planner/src/sql.rs b/dask_planner/src/sql.rs
index d9c453d35..5f5baaa34 100644
--- a/dask_planner/src/sql.rs
+++ b/dask_planner/src/sql.rs
@@ -148,32 +148,38 @@ impl ContextProvider for DaskSQLContext {
 
         match name {
             "year" => {
-                let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
+                let sig = generate_numeric_signatures(1);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
             }
-            "atan2" | "mod" => {
-                let sig = Signature::variadic(
-                    vec![DataType::Float64, DataType::Float64],
-                    Volatility::Immutable,
-                );
+            "mod" => {
+                let sig = generate_numeric_signatures(2);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
             }
             "cbrt" | "cot" | "degrees" | "radians" | "sign" | "truncate" => {
-                let sig = Signature::variadic(vec![DataType::Float64], Volatility::Immutable);
+                let sig = generate_numeric_signatures(1);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
             }
             "rand" => {
-                let sig = Signature::variadic(vec![DataType::Int64], Volatility::Volatile);
+                let sig = Signature::one_of(
+                    vec![
+                        TypeSignature::Exact(vec![]),
+                        TypeSignature::Exact(vec![DataType::Int64]),
+                    ],
+                    Volatility::Volatile, // results differ per call, so never constant-fold
+                );
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
             }
             "rand_integer" => {
-                let sig = Signature::variadic(
-                    vec![DataType::Int64, DataType::Int64],
-                    Volatility::Volatile,
+                let sig = Signature::one_of(
+                    vec![
+                        TypeSignature::Exact(vec![DataType::Int64]),
+                        TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
+                    ],
+                    Volatility::Volatile, // results differ per call, so never constant-fold
                 );
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
                 return Some(Arc::new(ScalarUDF::new(name, &sig, &rtf, &fun)));
@@ -231,39 +237,28 @@ impl ContextProvider for DaskSQLContext {
 
         match name {
             "every" => {
-                let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
+                let sig = generate_numeric_signatures(1);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Boolean)));
                 return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
             }
             "bit_and" | "bit_or" => {
-                let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
+                let sig = generate_numeric_signatures(1);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
                 return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
             }
             "single_value" => {
-                let sig = Signature::variadic(vec![DataType::Int64], Volatility::Immutable);
+                let sig = generate_numeric_signatures(1);
                 let rtf: ReturnTypeFunction =
                     Arc::new(|input_types| Ok(Arc::new(input_types[0].clone())));
                 return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
             }
             "regr_count" => {
-                let sig = Signature::variadic(
-                    vec![DataType::Float64, DataType::Float64],
-                    Volatility::Immutable,
-                );
+                let sig = generate_numeric_signatures(2);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Int64)));
                 return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
             }
             "regr_syy" | "regr_sxx" => {
-                let sig = Signature::variadic(
-                    vec![DataType::Float64, DataType::Float64],
-                    Volatility::Immutable,
-                );
-                let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
-                return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
-            }
-            "var_pop" => {
-                let sig = Signature::variadic(vec![DataType::Float64], Volatility::Immutable);
+                let sig = generate_numeric_signatures(2);
                 let rtf: ReturnTypeFunction = Arc::new(|_| Ok(Arc::new(DataType::Float64)));
                 return Some(Arc::new(AggregateUDF::new(name, &sig, &rtf, &acc, &st)));
             }
@@ -576,3 +571,85 @@ impl PlanVisitor for OptimizablePlanVisitor {
         Ok(true)
     }
 }
+
+/// Builds a signature that accepts exactly `n` arguments, each of which may be
+/// any of the integer or floating-point `DataType`s listed below.
+///
+/// The result enumerates every combination explicitly, i.e. 11^n `Exact`
+/// entries, so `n` is expected to stay small. For `n <= 0` the `one_of` list
+/// is empty.
+fn generate_numeric_signatures(n: i32) -> Signature {
+    let datatypes = vec![
+        DataType::Int8,
+        DataType::Int16,
+        DataType::Int32,
+        DataType::Int64,
+        DataType::UInt8,
+        DataType::UInt16,
+        DataType::UInt32,
+        DataType::UInt64,
+        DataType::Float16,
+        DataType::Float32,
+        DataType::Float64,
+    ];
+    // cartesian_setup = [datatypes, datatypes] when n == 2, etc.
+    let cartesian_setup = (0..n).map(|_| datatypes.clone()).collect();
+    generate_signatures(cartesian_setup)
+}
+
+/// Builds a `Signature::one_of` holding the Cartesian product of the given
+/// per-argument type lists: `cartesian_setup[i]` lists the types allowed for
+/// argument `i`, and one `TypeSignature::Exact` is emitted per combination.
+///
+/// Combinations are ordered with the last argument varying fastest. An empty
+/// `cartesian_setup`, or any empty inner list, yields an empty `one_of` list.
+fn generate_signatures(cartesian_setup: Vec<Vec<DataType>>) -> Signature {
+    let mut positions = cartesian_setup.iter();
+    // Seed with one single-element combination per type of the first argument.
+    let mut combinations: Vec<Vec<DataType>> = positions
+        .next()
+        .map(|first| first.iter().map(|datatype| vec![datatype.clone()]).collect())
+        .unwrap_or_default();
+    // Extend every partial combination with each type allowed at the next argument.
+    for candidates in positions {
+        combinations = combinations
+            .iter()
+            .flat_map(|prefix| {
+                candidates.iter().map(move |datatype| {
+                    let mut combination = prefix.clone();
+                    combination.push(datatype.clone());
+                    combination
+                })
+            })
+            .collect();
+    }
+
+    let one_of_vector = combinations.into_iter().map(TypeSignature::Exact).collect();
+    Signature::one_of(one_of_vector, Volatility::Immutable)
+}
+
+#[cfg(test)]
+mod test {
+    use arrow::datatypes::DataType;
+    use datafusion_expr::{Signature, TypeSignature, Volatility};
+
+    use crate::sql::generate_signatures;
+
+    #[test]
+    fn test_generate_signatures() {
+        let sig = generate_signatures(vec![
+            vec![DataType::Int64, DataType::Float64],
+            vec![DataType::Utf8, DataType::Int64],
+        ]);
+        let expected = Signature::one_of(
+            vec![
+                TypeSignature::Exact(vec![DataType::Int64, DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::Int64, DataType::Int64]),
+                TypeSignature::Exact(vec![DataType::Float64, DataType::Utf8]),
+                TypeSignature::Exact(vec![DataType::Float64, DataType::Int64]),
+            ],
+            Volatility::Immutable,
+        );
+        assert_eq!(sig, expected);
+    }
+}