diff --git a/src/udf.rs b/src/udf.rs
index 7d5db2f96..b35ad0e06 100644
--- a/src/udf.rs
+++ b/src/udf.rs
@@ -24,39 +24,51 @@ use datafusion::arrow::datatypes::DataType;
 use datafusion::arrow::pyarrow::FromPyArrow;
 use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
 use datafusion::error::DataFusionError;
-use datafusion::logical_expr::create_udf;
 use datafusion::logical_expr::function::ScalarFunctionImplementation;
 use datafusion::logical_expr::ScalarUDF;
+use datafusion::logical_expr::{create_udf, ColumnarValue};
 
 use crate::expr::PyExpr;
 use crate::utils::parse_volatility;
 
+/// Create a Rust callable function from a python function that expects pyarrow arrays
+fn pyarrow_function_to_rust(
+    func: PyObject,
+) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
+    move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
+        Python::with_gil(|py| {
+            // 1. cast args to Pyarrow arrays
+            let py_args = args
+                .iter()
+                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .collect::<Vec<_>>();
+            let py_args = PyTuple::new_bound(py, py_args);
+
+            // 2. call function
+            let value = func
+                .call_bound(py, py_args, None)
+                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+
+            // 3. cast to arrow::array::Array
+            let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
+            Ok(make_array(array_data))
+        })
+    }
+}
+
 /// Create a DataFusion's UDF implementation from a python function
 /// that expects pyarrow arrays. This is more efficient as it performs
 /// a zero-copy of the contents.
-fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
-    #[allow(deprecated)]
-    datafusion::physical_plan::functions::make_scalar_function(
-        move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
-            Python::with_gil(|py| {
-                // 1. cast args to Pyarrow arrays
-                let py_args = args
-                    .iter()
-                    .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
-                    .collect::<Vec<_>>();
-                let py_args = PyTuple::new_bound(py, py_args);
-
-                // 2. call function
-                let value = func
-                    .call_bound(py, py_args, None)
-                    .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
+fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
+    // Make the python function callable from rust
+    let pyarrow_func = pyarrow_function_to_rust(func);
 
-                // 3. cast to arrow::array::Array
-                let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
-                Ok(make_array(array_data))
-            })
-        },
-    )
+    // Convert input/output from datafusion ColumnarValue to arrow arrays
+    Arc::new(move |args: &[ColumnarValue]| {
+        let array_refs = ColumnarValue::values_to_arrays(args)?;
+        let array_result = pyarrow_func(&array_refs)?;
+        Ok(array_result.into())
+    })
 }
 
 /// Represents a PyScalarUDF
@@ -82,7 +94,7 @@ impl PyScalarUDF {
             input_types.0,
             Arc::new(return_type.0),
             parse_volatility(volatility)?,
-            to_rust_function(func),
+            to_scalar_function_impl(func),
         );
         Ok(Self { function })
     }