Skip to content

Commit

Permalink
fix: remove use of deprecated make_scalar_function
Browse files Browse the repository at this point in the history
`make_scalar_function` has been deprecated since v36 [0].
It is being removed from the public API in v43 [1].

[0]: apache/datafusion#8878
[1]: apache/datafusion#12505
  • Loading branch information
Michael-J-Ward committed Oct 11, 2024
1 parent cdec202 commit 7bab0a3
Showing 1 changed file with 36 additions and 24 deletions.
60 changes: 36 additions & 24 deletions src/udf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,39 +24,51 @@ use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::FromPyArrow;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::error::DataFusionError;
use datafusion::logical_expr::create_udf;
use datafusion::logical_expr::function::ScalarFunctionImplementation;
use datafusion::logical_expr::ScalarUDF;
use datafusion::logical_expr::{create_udf, ColumnarValue};

use crate::expr::PyExpr;
use crate::utils::parse_volatility;

/// Create a Rust callable function for a python function that expects pyarrow arrays.
///
/// The returned closure acquires the GIL, converts every input array to a
/// pyarrow object, calls `func` with those objects as positional arguments,
/// and converts the value returned by Python back into an Arrow array.
///
/// # Errors
///
/// The closure returns `DataFusionError::Execution` when an input array cannot
/// be converted to pyarrow, when the Python call itself fails, or when the
/// returned Python value cannot be converted into Arrow array data.
fn pyarrow_function_to_rust(
    func: PyObject,
) -> impl Fn(&[ArrayRef]) -> Result<ArrayRef, DataFusionError> {
    move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
        Python::with_gil(|py| {
            // 1. cast args to Pyarrow arrays
            // Conversion failures are propagated rather than unwrapped so that
            // they surface as execution errors instead of panics.
            let py_args = args
                .iter()
                .map(|arg| {
                    arg.into_data()
                        .to_pyarrow(py)
                        .map_err(|e| DataFusionError::Execution(format!("{e:?}")))
                })
                .collect::<Result<Vec<_>, _>>()?;
            let py_args = PyTuple::new_bound(py, py_args);

            // 2. call function
            let value = func
                .call_bound(py, py_args, None)
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;

            // 3. cast to arrow::array::Array
            // The Python function is caller-supplied, so its return value may
            // not be a pyarrow array; report that as an error, not a panic.
            let array_data = ArrayData::from_pyarrow_bound(value.bind(py))
                .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
            Ok(make_array(array_data))
        })
    }
}

/// Create a DataFusion UDF implementation from a python function
/// that expects pyarrow arrays. This is more efficient because the
/// array contents are passed zero-copy.
fn to_rust_function(func: PyObject) -> ScalarFunctionImplementation {
#[allow(deprecated)]
datafusion::physical_plan::functions::make_scalar_function(
move |args: &[ArrayRef]| -> Result<ArrayRef, DataFusionError> {
Python::with_gil(|py| {
// 1. cast args to Pyarrow arrays
let py_args = args
.iter()
.map(|arg| arg.into_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new_bound(py, py_args);

// 2. call function
let value = func
.call_bound(py, py_args, None)
.map_err(|e| DataFusionError::Execution(format!("{e:?}")))?;
fn to_scalar_function_impl(func: PyObject) -> ScalarFunctionImplementation {
// Make the python function callable from rust
let pyarrow_func = pyarrow_function_to_rust(func);

// 3. cast to arrow::array::Array
let array_data = ArrayData::from_pyarrow_bound(value.bind(py)).unwrap();
Ok(make_array(array_data))
})
},
)
// Convert input/output from datafusion ColumnarValue to arrow arrays
Arc::new(move |args: &[ColumnarValue]| {
let array_refs = ColumnarValue::values_to_arrays(args)?;
let array_result = pyarrow_func(&array_refs)?;
Ok(array_result.into())
})
}

/// Represents a PyScalarUDF
Expand All @@ -82,7 +94,7 @@ impl PyScalarUDF {
input_types.0,
Arc::new(return_type.0),
parse_volatility(volatility)?,
to_rust_function(func),
to_scalar_function_impl(func),
);
Ok(Self { function })
}
Expand Down

0 comments on commit 7bab0a3

Please sign in to comment.