diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs index c46badd64fa8..4e49bff09d9d 100644 --- a/benchmarks/src/bin/tpch.rs +++ b/benchmarks/src/bin/tpch.rs @@ -178,9 +178,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result>, // filters and limit can be used here to inject some push-down operations if needed _filters: &[Expr], diff --git a/datafusion/core/benches/sort_limit_query_sql.rs b/datafusion/core/benches/sort_limit_query_sql.rs index d1f253a982a5..198eb941f14d 100644 --- a/datafusion/core/benches/sort_limit_query_sql.rs +++ b/datafusion/core/benches/sort_limit_query_sql.rs @@ -89,8 +89,8 @@ fn create_context() -> Arc> { let ctx = SessionContext::new(); ctx.state.write().config.target_partitions = 1; - let task_ctx = ctx.task_ctx(); - let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx) + let table_provider = Arc::new(csv.await); + let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state()) .await .unwrap(); ctx.register_table("aggregate_test_100", Arc::new(mem_table)) diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index c8e0eef30316..7692a187ec6c 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -615,6 +615,7 @@ impl DataFrame { } } +// TODO: This will introduce a ref cycle (#2659) #[async_trait] impl TableProvider for DataFrame { fn as_any(&self) -> &dyn Any { @@ -632,6 +633,7 @@ impl TableProvider for DataFrame { async fn scan( &self, + _ctx: &SessionState, projection: &Option>, filters: &[Expr], limit: Option, diff --git a/datafusion/core/src/datasource/datasource.rs b/datafusion/core/src/datasource/datasource.rs index 8ab254525acb..17b288bef4e0 100644 --- a/datafusion/core/src/datasource/datasource.rs +++ b/datafusion/core/src/datasource/datasource.rs @@ -25,6 +25,7 @@ pub use datafusion_expr::{TableProviderFilterPushDown, TableType}; use crate::arrow::datatypes::SchemaRef; use crate::error::Result; +use crate::execution::context::SessionState; use crate::logical_plan::Expr; use crate::physical_plan::ExecutionPlan; @@ -47,6 +48,7 @@ pub trait TableProvider: Sync + Send { /// parallelized or distributed. async fn scan( &self, + ctx: &SessionState, projection: &Option>, filters: &[Expr], // limit can be used to reduce the amount scanned diff --git a/datafusion/core/src/datasource/empty.rs b/datafusion/core/src/datasource/empty.rs index 837cd7704460..3bc7a958c9e3 100644 --- a/datafusion/core/src/datasource/empty.rs +++ b/datafusion/core/src/datasource/empty.rs @@ -25,6 +25,7 @@ use async_trait::async_trait; use crate::datasource::{TableProvider, TableType}; use crate::error::Result; +use crate::execution::context::SessionState; use crate::logical_plan::Expr; use crate::physical_plan::project_schema; use crate::physical_plan::{empty::EmptyExec, ExecutionPlan}; @@ -57,6 +58,7 @@ impl TableProvider for EmptyTable { async fn scan( &self, + _ctx: &SessionState, projection: &Option>, _filters: &[Expr], _limit: Option, diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 34e44971d662..bc7ed26042a3 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -35,6 +35,7 @@ use crate::datasource::{ use crate::logical_expr::TableProviderFilterPushDown; use crate::{ error::{DataFusionError, Result}, + execution::context::SessionState, logical_plan::Expr, physical_plan::{ empty::EmptyExec, @@ -302,6 +303,7 @@ impl TableProvider for ListingTable { async fn scan( &self, + _ctx: &SessionState, projection: &Option>, filters: &[Expr], limit: Option, @@ -405,6 +407,7 @@ impl ListingTable { #[cfg(test)] mod tests { use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION; + use crate::prelude::SessionContext; use crate::{ datafusion_data_access::object_store::local::LocalFileSystem, datasource::file_format::{avro::AvroFormat, parquet::ParquetFormat}, @@ -417,10 +420,12 @@ mod tests { #[tokio::test] async fn read_single_file() -> Result<()> { + let ctx = SessionContext::new(); + let table = load_table("alltypes_plain.parquet").await?; let projection = None; let exec = table - .scan(&projection, &[], None) + .scan(&ctx.state(), &projection, &[], None) .await .expect("Scan table"); @@ -447,7 +452,9 @@ mod tests { .with_listing_options(opt) .with_schema(schema); let table = ListingTable::try_new(config)?; - let exec = table.scan(&None, &[], None).await?; + + let ctx = SessionContext::new(); + let exec = table.scan(&ctx.state(), &None, &[], None).await?; assert_eq!(exec.statistics().num_rows, Some(8)); assert_eq!(exec.statistics().total_byte_size, Some(671)); @@ -483,8 +490,9 @@ mod tests { // this will filter out the only file in the store let filter = Expr::not_eq(col("p1"), lit("v1")); + let ctx = SessionContext::new(); let scan = table - .scan(&None, &[filter], None) + .scan(&ctx.state(), &None, &[filter], None) .await .expect("Empty execution plan"); diff --git a/datafusion/core/src/datasource/memory.rs b/datafusion/core/src/datasource/memory.rs index adc26d2f41d3..62dca1ea04ab 100644 --- a/datafusion/core/src/datasource/memory.rs +++ b/datafusion/core/src/datasource/memory.rs @@ -29,7 +29,7 @@ use async_trait::async_trait; use crate::datasource::{TableProvider, TableType}; use crate::error::{DataFusionError, Result}; -use crate::execution::context::TaskContext; +use crate::execution::context::{SessionState, TaskContext}; use crate::logical_plan::Expr; use crate::physical_plan::common; use crate::physical_plan::memory::MemoryExec; @@ -65,18 +65,18 @@ impl MemTable { pub async fn load( t: Arc, output_partitions: Option, - context: Arc, + ctx: &SessionState, ) -> Result { let schema = t.schema(); - let exec = t.scan(&None, &[], None).await?; + let exec = t.scan(ctx, &None, &[], None).await?; let partition_count = exec.output_partitioning().partition_count(); let tasks = (0..partition_count) .map(|part_i| { - let context1 = context.clone(); + let task = Arc::new(TaskContext::from(ctx)); let exec = exec.clone(); tokio::spawn(async move { - let stream = exec.execute(part_i, context1.clone())?; + let stream = exec.execute(part_i, task)?; common::collect(stream).await }) }) @@ -103,7 +103,8 @@ impl MemTable { let mut output_partitions = vec![]; for i in 0..exec.output_partitioning().partition_count() { // execute this *output* partition and collect all batches - let mut stream = exec.execute(i, context.clone())?; + let task_ctx = Arc::new(TaskContext::from(ctx)); + let mut stream = exec.execute(i, task_ctx)?; let mut batches = vec![]; while let Some(result) = stream.next().await { batches.push(result?); @@ -133,6 +134,7 @@ impl TableProvider for MemTable { async fn scan( &self, + _ctx: &SessionState, projection: &Option>, _filters: &[Expr], _limit: Option, @@ -180,7 +182,10 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; // scan with projection - let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?; + let exec = provider + .scan(&session_ctx.state(), &Some(vec![2, 1]), &[], None) + .await?; + let mut it = exec.execute(0, task_ctx)?; let batch2 = it.next().await.unwrap()?; assert_eq!(2, batch2.schema().fields().len()); @@ -212,7 +217,9 @@ mod tests { let provider = MemTable::try_new(schema, vec![vec![batch]])?; - let exec = provider.scan(&None, &[], None).await?; + let exec = provider + .scan(&session_ctx.state(), &None, &[], None) + .await?; let mut it = exec.execute(0, task_ctx)?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); @@ -223,6 +230,8 @@ mod tests { #[tokio::test] async fn test_invalid_projection() -> Result<()> { + let session_ctx = SessionContext::new(); + let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Int32, false), Field::new("b", DataType::Int32, false), @@ -242,7 +251,10 @@ mod tests { let projection: Vec = vec![0, 4]; - match provider.scan(&Some(projection), &[], None).await { + match provider + .scan(&session_ctx.state(), &Some(projection), &[], None) + .await + { Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => { assert_eq!( "\"project index 4 out of bounds, max field 3\"", @@ -368,7 +380,9 @@ mod tests { let provider = MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?; - let exec = provider.scan(&None, &[], None).await?; + let exec = provider + .scan(&session_ctx.state(), &None, &[], None) + .await?; let mut it = exec.execute(0, task_ctx)?; let batch1 = it.next().await.unwrap()?; assert_eq!(3, batch1.schema().fields().len()); diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 3db76cee1414..18a43d3d4f09 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -24,17 +24,15 @@ use async_trait::async_trait; use crate::{ error::Result, - execution::context::SessionContext, logical_plan::{Expr, LogicalPlan}, physical_plan::ExecutionPlan, }; use crate::datasource::{TableProvider, TableType}; +use crate::execution::context::SessionState; /// An implementation of `TableProvider` that uses another logical plan. pub struct ViewTable { - /// To create ExecutionPlan - context: SessionContext, /// LogicalPlan of the view logical_plan: LogicalPlan, /// File fields + partition columns @@ -44,11 +42,10 @@ pub struct ViewTable { impl ViewTable { /// Create new view that is executed at query runtime. /// Takes a `LogicalPlan` as input. - pub fn try_new(context: SessionContext, logical_plan: LogicalPlan) -> Result { + pub fn try_new(logical_plan: LogicalPlan) -> Result { let table_schema = logical_plan.schema().as_ref().to_owned().into(); let view = Self { - context, logical_plan, table_schema, }; @@ -73,16 +70,18 @@ impl TableProvider for ViewTable { async fn scan( &self, + ctx: &SessionState, _projection: &Option>, _filters: &[Expr], _limit: Option, ) -> Result> { - self.context.create_physical_plan(&self.logical_plan).await + ctx.create_physical_plan(&self.logical_plan).await } } #[cfg(test)] mod tests { + use crate::prelude::SessionContext; use crate::{assert_batches_eq, execution::context::SessionConfig}; use super::*; diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 4d579776e666..ba3f86c69ec9 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -349,16 +349,14 @@ impl SessionContext { (true, Ok(_)) => { self.deregister_table(name.as_str())?; let plan = self.optimize(&input)?; - let table = - Arc::new(ViewTable::try_new(self.clone(), plan.clone())?); + let table = Arc::new(ViewTable::try_new(plan.clone())?); self.register_table(name.as_str(), table)?; Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))) } (_, Err(_)) => { let plan = self.optimize(&input)?; - let table = - Arc::new(ViewTable::try_new(self.clone(), plan.clone())?); + let table = Arc::new(ViewTable::try_new(plan.clone())?); self.register_table(name.as_str(), table)?; Ok(Arc::new(DataFrame::new(self.state.clone(), &plan))) @@ -931,6 +929,11 @@ impl SessionContext { pub fn task_ctx(&self) -> Arc { Arc::new(TaskContext::from(self)) } + + /// Get a copy of the [`SessionState`] of this [`SessionContext`] + pub fn state(&self) -> SessionState { + self.state.read().clone() + } } impl FunctionRegistry for SessionContext { diff --git a/datafusion/core/src/physical_plan/planner.rs b/datafusion/core/src/physical_plan/planner.rs index 39e5e0000b10..ad957409c826 100644 --- a/datafusion/core/src/physical_plan/planner.rs +++ b/datafusion/core/src/physical_plan/planner.rs @@ -389,7 +389,7 @@ impl DefaultPhysicalPlanner { // referred to in the query let filters = unnormalize_cols(filters.iter().cloned()); let unaliased: Vec = filters.into_iter().map(unalias).collect(); - source.scan(projection, &unaliased, *limit).await + source.scan(session_state, projection, &unaliased, *limit).await } LogicalPlan::Values(Values { values, diff --git a/datafusion/core/tests/custom_sources.rs b/datafusion/core/tests/custom_sources.rs index f1356f7d4431..1e4ac6e5134e 100644 --- a/datafusion/core/tests/custom_sources.rs +++ b/datafusion/core/tests/custom_sources.rs @@ -30,7 +30,7 @@ use datafusion::{ }; use datafusion::{error::Result, physical_plan::DisplayFormatType}; -use datafusion::execution::context::{SessionContext, TaskContext}; +use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_plan::{ col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE, }; @@ -201,6 +201,7 @@ impl TableProvider for CustomTableProvider { async fn scan( &self, + _state: &SessionState, projection: &Option>, _filters: &[Expr], _limit: Option, diff --git a/datafusion/core/tests/provider_filter_pushdown.rs b/datafusion/core/tests/provider_filter_pushdown.rs index 79c71afb341a..9b9ba84d3a5f 100644 --- a/datafusion/core/tests/provider_filter_pushdown.rs +++ b/datafusion/core/tests/provider_filter_pushdown.rs @@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch; use async_trait::async_trait; use datafusion::datasource::datasource::{TableProvider, TableType}; use datafusion::error::Result; -use datafusion::execution::context::{SessionContext, TaskContext}; +use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; use datafusion::logical_expr::{Expr, TableProviderFilterPushDown}; use datafusion::physical_plan::common::SizedRecordBatchStream; use datafusion::physical_plan::expressions::PhysicalSortExpr; @@ -138,6 +138,7 @@ impl TableProvider for CustomProvider { async fn scan( &self, + _state: &SessionState, _: &Option>, filters: &[Expr], _: Option, diff --git a/datafusion/core/tests/sql/information_schema.rs b/datafusion/core/tests/sql/information_schema.rs index a7b6bdb45175..1ed727be8e4a 100644 --- a/datafusion/core/tests/sql/information_schema.rs +++ b/datafusion/core/tests/sql/information_schema.rs @@ -16,6 +16,7 @@ // under the License. use async_trait::async_trait; +use datafusion::execution::context::SessionState; use datafusion::{ catalog::{ catalog::{CatalogProvider, MemoryCatalogProvider}, @@ -175,6 +176,7 @@ async fn information_schema_tables_table_types() { async fn scan( &self, + _ctx: &SessionState, _: &Option>, _: &[Expr], _: Option, diff --git a/datafusion/core/tests/sql/window.rs b/datafusion/core/tests/sql/window.rs index bdbc77067ebd..120028ac486b 100644 --- a/datafusion/core/tests/sql/window.rs +++ b/datafusion/core/tests/sql/window.rs @@ -328,7 +328,7 @@ async fn window_expr_eliminate() -> Result<()> { let plan = ctx .create_logical_plan(&("explain ".to_owned() + sql)) .expect(&msg); - let state = ctx.state.read().clone(); + let state = ctx.state(); let plan = state.optimize(&plan)?; let expected = vec![ "Explain [plan_type:Utf8, plan:Utf8]", diff --git a/datafusion/core/tests/statistics.rs b/datafusion/core/tests/statistics.rs index 99b53a62d8ee..95879ebaf679 100644 --- a/datafusion/core/tests/statistics.rs +++ b/datafusion/core/tests/statistics.rs @@ -34,7 +34,7 @@ use datafusion::{ }; use async_trait::async_trait; -use datafusion::execution::context::TaskContext; +use datafusion::execution::context::{SessionState, TaskContext}; /// This is a testing structure for statistics /// It will act both as a table provider and execution plan @@ -74,6 +74,7 @@ impl TableProvider for StatisticsValidation { async fn scan( &self, + _ctx: &SessionState, projection: &Option>, filters: &[Expr], // limit is ignored because it is not mandatory for a `TableProvider` to honor it