From 215f8db8eca4f0377d222d75c5572fb1c6a35499 Mon Sep 17 00:00:00 2001
From: Raphael Taylor-Davies
Date: Wed, 14 Dec 2022 22:24:59 +0000
Subject: [PATCH] DataFrame owned SessionState (#4617)

---
 .../examples/custom_datasource.rs        |  2 +-
 datafusion/core/src/dataframe.rs         | 58 +++++--------------
 datafusion/core/src/execution/context.rs | 16 ++---
 3 files changed, 25 insertions(+), 51 deletions(-)

diff --git a/datafusion-examples/examples/custom_datasource.rs b/datafusion-examples/examples/custom_datasource.rs
index db4fed494ca1..68e8f5a54630 100644
--- a/datafusion-examples/examples/custom_datasource.rs
+++ b/datafusion-examples/examples/custom_datasource.rs
@@ -69,7 +69,7 @@ async fn search_accounts(
         )?
         .build()?;
 
-    let mut dataframe = DataFrame::new(ctx.state, logical_plan)
+    let mut dataframe = DataFrame::new(ctx.state(), logical_plan)
         .select_columns(&["id", "bank_account"])?;
 
     if let Some(f) = filter {
diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs
index 5e615be607c3..a6b43d3f0548 100644
--- a/datafusion/core/src/dataframe.rs
+++ b/datafusion/core/src/dataframe.rs
@@ -21,7 +21,6 @@ use std::any::Any;
 use std::sync::Arc;
 
 use async_trait::async_trait;
-use parking_lot::RwLock;
 use parquet::file::properties::WriterProperties;
 
 use datafusion_common::{Column, DFSchema};
@@ -74,13 +73,13 @@ use crate::prelude::SessionContext;
 /// ```
 #[derive(Debug, Clone)]
 pub struct DataFrame {
-    session_state: Arc<RwLock<SessionState>>,
+    session_state: SessionState,
     plan: LogicalPlan,
 }
 
 impl DataFrame {
     /// Create a new Table based on an existing logical plan
-    pub fn new(session_state: Arc<RwLock<SessionState>>, plan: LogicalPlan) -> Self {
+    pub fn new(session_state: SessionState, plan: LogicalPlan) -> Self {
         Self {
             session_state,
             plan,
@@ -88,26 +87,8 @@ impl DataFrame {
     }
 
     /// Create a physical plan
-    pub async fn create_physical_plan(self) -> Result<Arc<dyn ExecutionPlan>> {
-        // this function is copied from SessionContext function of the
-        // same name
-        let state_cloned = {
-            let mut state = self.session_state.write();
-            state.execution_props.start_execution();
-
-            // We need to clone `state` to release the lock that is not `Send`. We could
-            // make the lock `Send` by using `tokio::sync::Mutex`, but that would require to
-            // propagate async even to the `LogicalPlan` building methods.
-            // Cloning `state` here is fine as we then pass it as immutable `&state`, which
-            // means that we avoid write consistency issues as the cloned version will not
-            // be written to. As for eventual modifications that would be applied to the
-            // original state after it has been cloned, they will not be picked up by the
-            // clone but that is okay, as it is equivalent to postponing the state update
-            // by keeping the lock until the end of the function scope.
-            state.clone()
-        };
-
-        state_cloned.create_physical_plan(&self.plan).await
+    pub async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
+        self.session_state.create_physical_plan(&self.plan).await
     }
 
     /// Filter the DataFrame by column. Returns a new DataFrame only containing the
@@ -437,8 +418,7 @@ impl DataFrame {
     }
 
     fn task_ctx(&self) -> TaskContext {
-        let lock = self.session_state.read();
-        TaskContext::from(&*lock)
+        TaskContext::from(&self.session_state)
     }
 
     /// Executes this DataFrame and returns a stream over a single partition
@@ -527,8 +507,7 @@ impl DataFrame {
     /// Return the optimized logical plan represented by this DataFrame.
     pub fn to_logical_plan(self) -> Result<LogicalPlan> {
         // Optimize the plan first for better UX
-        let state = self.session_state.read().clone();
-        state.optimize(&self.plan)
+        self.session_state.optimize(&self.plan)
     }
 
     /// Return a DataFrame with the explanation of its plan so far.
@@ -567,9 +546,8 @@ impl DataFrame {
     /// # Ok(())
     /// # }
     /// ```
-    pub fn registry(&self) -> Arc<dyn FunctionRegistry> {
-        let registry = self.session_state.read().clone();
-        Arc::new(registry)
+    pub fn registry(&self) -> &dyn FunctionRegistry {
+        &self.session_state
     }
 
     /// Calculate the intersection of two [`DataFrame`]s. The two [`DataFrame`]s must have exactly the same schema
@@ -621,9 +599,8 @@ impl DataFrame {
 
     /// Write a `DataFrame` to a CSV file.
     pub async fn write_csv(self, path: &str) -> Result<()> {
-        let state = self.session_state.read().clone();
         let plan = self.create_physical_plan().await?;
-        plan_to_csv(&state, plan, path).await
+        plan_to_csv(&self.session_state, plan, path).await
     }
 
     /// Write a `DataFrame` to a Parquet file.
@@ -632,16 +609,14 @@ impl DataFrame {
         path: &str,
         writer_properties: Option<WriterProperties>,
     ) -> Result<()> {
-        let state = self.session_state.read().clone();
         let plan = self.create_physical_plan().await?;
-        plan_to_parquet(&state, plan, path, writer_properties).await
+        plan_to_parquet(&self.session_state, plan, path, writer_properties).await
     }
 
     /// Executes a query and writes the results to a partitioned JSON file.
     pub async fn write_json(self, path: impl AsRef<str>) -> Result<()> {
-        let state = self.session_state.read().clone();
         let plan = self.create_physical_plan().await?;
-        plan_to_json(&state, plan, path).await
+        plan_to_json(&self.session_state, plan, path).await
     }
 
     /// Add an additional column to the DataFrame.
@@ -747,7 +722,7 @@ impl DataFrame {
     /// # }
     /// ```
     pub async fn cache(self) -> Result<DataFrame> {
-        let context = SessionContext::with_state(self.session_state.read().clone());
+        let context = SessionContext::with_state(self.session_state.clone());
         let mem_table = MemTable::try_new(
             SchemaRef::from(self.schema().clone()),
             self.collect_partitioned().await?,
@@ -1029,9 +1004,8 @@ mod tests {
 
         // build query with a UDF using DataFrame API
         let df = ctx.table("aggregate_test_100")?;
-        let f = df.registry();
-
-        let df = df.select(vec![f.udf("my_fn")?.call(vec![col("c12")])])?;
+        let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]);
+        let df = df.select(vec![expr])?;
 
         // build query using SQL
         let sql_plan =
@@ -1088,7 +1062,7 @@ mod tests {
     async fn register_table() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c12"])?;
         let ctx = SessionContext::new();
-        let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
+        let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
 
         // register a dataframe as a table
         ctx.register_table("test_table", Arc::new(df_impl.clone()))?;
@@ -1180,7 +1154,7 @@ mod tests {
     async fn with_column() -> Result<()> {
         let df = test_table().await?.select_columns(&["c1", "c2", "c3"])?;
         let ctx = SessionContext::new();
-        let df_impl = DataFrame::new(ctx.state.clone(), df.plan.clone());
+        let df_impl = DataFrame::new(ctx.state(), df.plan.clone());
 
         let df = df_impl
             .filter(col("c2").eq(lit(3)).and(col("c1").eq(lit("a"))))?
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index f98e21dd1dd5..359b1c48b5e1 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -273,7 +273,7 @@ impl SessionContext {
             (false, true, Ok(_)) => {
                 self.deregister_table(&name)?;
                 let schema = Arc::new(input.schema().as_ref().into());
-                let physical = DataFrame::new(self.state.clone(), input);
+                let physical = DataFrame::new(self.state(), input);
 
                 let batches: Vec<_> = physical.collect_partitioned().await?;
                 let table = Arc::new(MemTable::try_new(schema, batches)?);
@@ -286,7 +286,7 @@ impl SessionContext {
             )),
             (_, _, Err(_)) => {
                 let schema = Arc::new(input.schema().as_ref().into());
-                let physical = DataFrame::new(self.state.clone(), input);
+                let physical = DataFrame::new(self.state(), input);
 
                 let batches: Vec<_> = physical.collect_partitioned().await?;
                 let table = Arc::new(MemTable::try_new(schema, batches)?);
@@ -475,14 +475,14 @@ impl SessionContext {
                 }
             }
 
-            plan => Ok(DataFrame::new(self.state.clone(), plan)),
+            plan => Ok(DataFrame::new(self.state(), plan)),
         }
     }
 
     // return an empty dataframe
     fn return_empty_dataframe(&self) -> Result<DataFrame> {
         let plan = LogicalPlanBuilder::empty(false).build()?;
-        Ok(DataFrame::new(self.state.clone(), plan))
+        Ok(DataFrame::new(self.state(), plan))
     }
 
     async fn create_external_table(
@@ -661,7 +661,7 @@ impl SessionContext {
     /// Creates an empty DataFrame.
     pub fn read_empty(&self) -> Result<DataFrame> {
         Ok(DataFrame::new(
-            self.state.clone(),
+            self.state(),
             LogicalPlanBuilder::empty(true).build()?,
         ))
     }
@@ -716,7 +716,7 @@ impl SessionContext {
     /// Creates a [`DataFrame`] for reading a custom [`TableProvider`].
     pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
         Ok(DataFrame::new(
-            self.state.clone(),
+            self.state(),
             LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
                 .build()?,
         ))
     }
 
     /// Creates a [`DataFrame`] for reading a [`RecordBatch`]
     pub fn read_batch(&self, batch: RecordBatch) -> Result<DataFrame> {
         let provider = MemTable::try_new(batch.schema(), vec![vec![batch]])?;
         Ok(DataFrame::new(
-            self.state.clone(),
+            self.state(),
             LogicalPlanBuilder::scan(
                 UNNAMED_TABLE,
                 provider_as_source(Arc::new(provider)),
@@ -946,7 +946,7 @@ impl SessionContext {
                 None,
             )?
             .build()?;
-        Ok(DataFrame::new(self.state.clone(), plan))
+        Ok(DataFrame::new(self.state(), plan))
     }
 
     /// Return a [`TableProvider`] for the specified table.
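
For illustration only (not part of the patch): after this change a `DataFrame` owns a `SessionState` snapshot obtained from `SessionContext::state()`, rather than sharing an `Arc<RwLock<SessionState>>` with the context. A minimal usage sketch against this revision; the `example.csv` file, table name, and `tokio` `main` scaffold are assumptions for the example:

```rust
use datafusion::error::Result;
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    ctx.register_csv("example", "example.csv", CsvReadOptions::new())
        .await?;

    // Build a logical plan through the context, then wrap it in a DataFrame
    // that owns a clone of the session state taken at this point in time.
    let plan = ctx.table("example")?.to_logical_plan()?;
    let df = DataFrame::new(ctx.state(), plan);

    // Planning and execution now go through the owned state; no lock is taken.
    df.show().await?;
    Ok(())
}
```

Because `ctx.state()` hands out an owned clone, configuration or catalog changes made to the context after the `DataFrame` is created are not visible to that `DataFrame`.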