From b9bf00ef97f30546585c6a729f7827cbc4b14c1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= Date: Thu, 29 Feb 2024 20:10:37 +0000 Subject: [PATCH] Address PR comments (factory interface) --- datafusion/core/src/execution/context/mod.rs | 71 +++++++++++++------ .../user_defined_scalar_functions.rs | 41 ++++++----- 2 files changed, 72 insertions(+), 40 deletions(-) diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 41bdcdae2b4ed..2847a3f9f4203 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -798,28 +798,48 @@ impl SessionContext { } async fn create_function(&self, stmt: CreateFunction) -> Result { - let function_factory = self.state.read().function_factory.clone(); + let function = { + let state = self.state.read().clone(); + let function_factory = &state.function_factory; + + match function_factory { + Some(f) => f.create(state.config(), stmt).await?, + _ => Err(DataFusionError::Configuration( + "Function factory has not been configured".into(), + ))?, + } + }; - match function_factory { - Some(f) => f.create(self.state.clone(), stmt).await?, - None => Err(DataFusionError::Configuration( - "Function factory has not been configured".into(), - ))?, + match function { + RegisterFunction::Scalar(f) => { + self.state.write().register_udf(f)?; + } + RegisterFunction::Aggregate(f) => { + self.state.write().register_udaf(f)?; + } + RegisterFunction::Window(f) => { + self.state.write().register_udwf(f)?; + } + RegisterFunction::Table(name, f) => self.register_udtf(&name, f), }; self.return_empty_dataframe() } async fn drop_function(&self, stmt: DropFunction) -> Result { - let function_factory = self.state.read().function_factory.clone(); - - match function_factory { - Some(f) => f.remove(self.state.clone(), stmt).await?, - None => Err(DataFusionError::Configuration( - "Function factory has not been configured".into(), - ))?, + let _function = { + let state = self.state.read().clone(); + let function_factory = &state.function_factory; + + match function_factory { + Some(f) => f.remove(state.config(), stmt).await?, + None => Err(DataFusionError::Configuration( + "Function factory has not been configured".into(), + ))?, + } }; + // TODO: Once we have unregister UDF we need to implement it here self.return_empty_dataframe() } @@ -1289,27 +1309,36 @@ impl QueryPlanner for DefaultQueryPlanner { /// ``` #[async_trait] pub trait FunctionFactory: Sync + Send { - // TODO: I don't like having RwLock Leaking here, who ever implements it - // has to depend ot `parking_lot`. I'f we expose &mut SessionState it - // may keep lock of too long. // - // Not sure if there is better approach. + // This api holds a read lock for state // /// Handles creation of user defined function specified in [CreateFunction] statement async fn create( &self, - state: Arc>, + state: &SessionConfig, statement: CreateFunction, - ) -> Result<()>; + ) -> Result; /// Drops user defined function from [SessionState] // Naming it `drop`` would make more sense but its already occupied in rust async fn remove( &self, - state: Arc>, + state: &SessionConfig, statement: DropFunction, - ) -> Result<()>; + ) -> Result; +} + +/// Type of function to create +pub enum RegisterFunction { + /// Scalar user defined function + Scalar(Arc), + /// Aggregate user defined function + Aggregate(Arc), + /// Window user defined function + Window(Arc), + /// Table user defined function + Table(String, Arc), } /// Execution context for registering data sources and executing queries. /// See [`SessionContext`] for a higher level API. diff --git a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs index 4f9b21573e88f..fe07f614a33ca 100644 --- a/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs +++ b/datafusion/core/tests/user_defined/user_defined_scalar_functions.rs @@ -21,7 +21,7 @@ use arrow_array::{ }; use arrow_schema::DataType::Float64; use arrow_schema::{DataType, Field, Schema}; -use datafusion::execution::context::{FunctionFactory, SessionState}; +use datafusion::execution::context::{FunctionFactory, RegisterFunction, SessionState}; use datafusion::prelude::*; use datafusion::{execution::registry::FunctionRegistry, test_util}; use datafusion_common::cast::as_float64_array; @@ -34,7 +34,7 @@ use datafusion_expr::{ create_udaf, create_udf, Accumulator, ColumnarValue, CreateFunction, DropFunction, ExprSchemable, LogicalPlanBuilder, ScalarUDF, ScalarUDFImpl, Signature, Volatility, }; -use parking_lot::{Mutex, RwLock}; +use parking_lot::Mutex; use rand::{thread_rng, Rng}; use std::any::Any; use std::iter; @@ -636,9 +636,9 @@ impl FunctionFactory for MockFunctionFactory { #[allow(clippy::type_complexity, clippy::type_repetition_in_bounds)] async fn create( &self, - state: Arc>, + _config: &SessionConfig, statement: CreateFunction, - ) -> datafusion::error::Result<()> { + ) -> datafusion::error::Result { // this function is a mock for testing // `CreateFunction` should be used to derive this function @@ -675,22 +675,26 @@ impl FunctionFactory for MockFunctionFactory { // it has been parsed *self.captured_expr.lock() = statement.params.return_; - // we may need other infrastructure provided by state, for example: - // state.config().get_extension() - - // register mock udf for testing - state.write().register_udf(mock_udf.into())?; - Ok(()) + Ok(RegisterFunction::Scalar(Arc::new(mock_udf))) } async fn remove( &self, - _state: Arc>, + _config: &SessionConfig, _statement: DropFunction, - ) -> datafusion::error::Result<()> { - // at the moment state does not support unregister - // ignoring for now - Ok(()) + ) -> datafusion::error::Result { + + // TODO: I don't like that remove returns RegisterFunction + // we have to keep two states in FunctionFactory iml and + // SessionState + // + // It would be better to return (function_name, function type) tuple + + // at the moment state does not support unregister user defined functions + + Err(DataFusionError::NotImplemented( + "remove function has not been implemented".into(), + )) } } @@ -722,15 +726,14 @@ async fn create_scalar_function_from_sql_statement() { .await .unwrap(); - // sql expression should be convert to datafusion expression - // in this case + // check if we sql expr has been converted to datafusion expr let captured_expression = function_factory.captured_expr.lock().clone().unwrap(); // is there some better way to test this assert_eq!("$1 + $2", captured_expression.to_string()); - println!("{:?}", captured_expression); - ctx.sql("drop function better_add").await.unwrap(); + // no support at the moment + // ctx.sql("drop function better_add").await.unwrap(); } fn create_udf_context() -> SessionContext {