From fad77a4b278a441e343fb6caa76f6e3bcd27e248 Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Thu, 5 Jan 2023 11:03:19 +0000 Subject: [PATCH] Make SchemaProvider::table async (#4607) * Make SchemaProvider async (#3777) * Cleanup * Workaround for ShowCreate * Clippy and doctest * Update sqllogictests * Review feedback * Format --- datafusion-cli/Cargo.lock | 12 + .../examples/dataframe_in_memory.rs | 2 +- datafusion-examples/examples/simple_udaf.rs | 2 +- datafusion-examples/examples/simple_udf.rs | 2 +- datafusion/common/src/table_reference.rs | 6 + datafusion/core/Cargo.toml | 2 +- .../core/src/catalog/information_schema.rs | 25 +- datafusion/core/src/catalog/listing_schema.rs | 4 +- datafusion/core/src/catalog/schema.rs | 7 +- datafusion/core/src/dataframe.rs | 27 ++- datafusion/core/src/datasource/view.rs | 10 +- datafusion/core/src/execution/context.rs | 216 ++++++++++++------ datafusion/core/tests/dataframe.rs | 18 +- datafusion/core/tests/dataframe_functions.rs | 10 +- datafusion/core/tests/sql/create_drop.rs | 2 +- datafusion/core/tests/sql/errors.rs | 15 +- .../core/tests/sql/information_schema.rs | 3 +- datafusion/core/tests/sql/projection.rs | 4 +- datafusion/core/tests/sql/udf.rs | 2 +- .../tests/sqllogictests/src/insert/mod.rs | 4 +- .../tests/sqllogictests/test_files/ddl.slt | 4 +- .../test_files/information_schema.slt | 4 +- datafusion/proto/README.md | 4 +- .../proto/examples/logical_plan_serde.rs | 2 +- .../proto/examples/physical_plan_serde.rs | 2 +- datafusion/proto/src/logical_plan/mod.rs | 6 +- 26 files changed, 256 insertions(+), 139 deletions(-) diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index a34d9f4d486d..7eef17b4e336 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -2264,6 +2264,18 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db67dc6ef36edb658196c3fef0464a80b53dbbc194a904e81f9bd4190f9ecc5b" dependencies = [ "log", + "sqlparser_derive", +] + +[[package]] +name = "sqlparser_derive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55fe75cb4a364c7f7ae06c7dbbc8d84bddd85d6cdf9975963c3935bc1991761e" +dependencies = [ + "proc-macro2", + "quote", + "syn", ] [[package]] diff --git a/datafusion-examples/examples/dataframe_in_memory.rs b/datafusion-examples/examples/dataframe_in_memory.rs index 0702573e4fa0..be622d469f4d 100644 --- a/datafusion-examples/examples/dataframe_in_memory.rs +++ b/datafusion-examples/examples/dataframe_in_memory.rs @@ -47,7 +47,7 @@ async fn main() -> Result<()> { // declare a table in memory. In spark API, this corresponds to createDataFrame(...). ctx.register_batch("t", batch)?; - let df = ctx.table("t")?; + let df = ctx.table("t").await?; // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL let filter = col("b").eq(lit(10)); diff --git a/datafusion-examples/examples/simple_udaf.rs b/datafusion-examples/examples/simple_udaf.rs index f4e0d3dd9793..d171f6579bfe 100644 --- a/datafusion-examples/examples/simple_udaf.rs +++ b/datafusion-examples/examples/simple_udaf.rs @@ -179,7 +179,7 @@ async fn main() -> Result<()> { // get a DataFrame from the context // this table has 1 column `a` f32 with values {2,4,8,64}, whose geometric mean is 8.0. - let df = ctx.table("t")?; + let df = ctx.table("t").await?; // perform the aggregation let df = df.aggregate(vec![], vec![geometric_mean.call(vec![col("a")])])?; diff --git a/datafusion-examples/examples/simple_udf.rs b/datafusion-examples/examples/simple_udf.rs index c9044a87abc9..f735f9938fe8 100644 --- a/datafusion-examples/examples/simple_udf.rs +++ b/datafusion-examples/examples/simple_udf.rs @@ -122,7 +122,7 @@ async fn main() -> Result<()> { let expr = pow.call(vec![col("a"), col("b")]); // get a DataFrame from the context - let df = ctx.table("t")?; + let df = ctx.table("t").await?; // if we do not have `pow` in the scope and we registered it, we can get it from the registry let pow = df.registry().udf("pow")?; diff --git a/datafusion/common/src/table_reference.rs b/datafusion/common/src/table_reference.rs index e547d7c03a60..0c74a35ae486 100644 --- a/datafusion/common/src/table_reference.rs +++ b/datafusion/common/src/table_reference.rs @@ -26,6 +26,12 @@ pub struct ResolvedTableReference<'a> { pub table: &'a str, } +impl<'a> std::fmt::Display for ResolvedTableReference<'a> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}.{}.{}", self.catalog, self.schema, self.table) + } +} + /// Represents a path to a table that may require further resolution #[derive(Debug, Clone, Copy)] pub enum TableReference<'a> { diff --git a/datafusion/core/Cargo.toml b/datafusion/core/Cargo.toml index 44c25f593c4a..098c3099f72b 100644 --- a/datafusion/core/Cargo.toml +++ b/datafusion/core/Cargo.toml @@ -91,7 +91,7 @@ pyo3 = { version = "0.17.1", optional = true } rand = "0.8" rayon = { version = "1.5", optional = true } smallvec = { version = "1.6", features = ["union"] } -sqlparser = "0.30" +sqlparser = { version = "0.30", features = ["visitor"] } tempfile = "3" tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "fs", "parking_lot"] } tokio-stream = "0.1" diff --git a/datafusion/core/src/catalog/information_schema.rs b/datafusion/core/src/catalog/information_schema.rs index cad20ebfed64..7119558e8763 100644 --- a/datafusion/core/src/catalog/information_schema.rs +++ b/datafusion/core/src/catalog/information_schema.rs @@ -19,6 +19,7 @@ //! //! Information Schema] +use async_trait::async_trait; use std::{any::Any, sync::Arc}; use arrow::{ @@ -43,6 +44,9 @@ pub const VIEWS: &str = "views"; pub const COLUMNS: &str = "columns"; pub const DF_SETTINGS: &str = "df_settings"; +/// All information schema tables +pub const INFORMATION_SCHEMA_TABLES: &[&str] = &[TABLES, VIEWS, COLUMNS, DF_SETTINGS]; + /// Implements the `information_schema` virtual schema and tables /// /// The underlying tables in the `information_schema` are created on @@ -69,7 +73,7 @@ struct InformationSchemaConfig { impl InformationSchemaConfig { /// Construct the `information_schema.tables` virtual table - fn make_tables(&self, builder: &mut InformationSchemaTablesBuilder) { + async fn make_tables(&self, builder: &mut InformationSchemaTablesBuilder) { // create a mem table with the names of tables for catalog_name in self.catalog_list.catalog_names() { @@ -79,7 +83,7 @@ impl InformationSchemaConfig { if schema_name != INFORMATION_SCHEMA { let schema = catalog.schema(&schema_name).unwrap(); for table_name in schema.table_names() { - let table = schema.table(&table_name).unwrap(); + let table = schema.table(&table_name).await.unwrap(); builder.add_table( &catalog_name, &schema_name, @@ -108,7 +112,7 @@ impl InformationSchemaConfig { } } - fn make_views(&self, builder: &mut InformationSchemaViewBuilder) { + async fn make_views(&self, builder: &mut InformationSchemaViewBuilder) { for catalog_name in self.catalog_list.catalog_names() { let catalog = self.catalog_list.catalog(&catalog_name).unwrap(); @@ -116,7 +120,7 @@ impl InformationSchemaConfig { if schema_name != INFORMATION_SCHEMA { let schema = catalog.schema(&schema_name).unwrap(); for table_name in schema.table_names() { - let table = schema.table(&table_name).unwrap(); + let table = schema.table(&table_name).await.unwrap(); builder.add_view( &catalog_name, &schema_name, @@ -130,7 +134,7 @@ impl InformationSchemaConfig { } /// Construct the `information_schema.columns` virtual table - fn make_columns(&self, builder: &mut InformationSchemaColumnsBuilder) { + async fn make_columns(&self, builder: &mut InformationSchemaColumnsBuilder) { for catalog_name in self.catalog_list.catalog_names() { let catalog = self.catalog_list.catalog(&catalog_name).unwrap(); @@ -138,7 +142,7 @@ impl InformationSchemaConfig { if schema_name != INFORMATION_SCHEMA { let schema = catalog.schema(&schema_name).unwrap(); for table_name in schema.table_names() { - let table = schema.table(&table_name).unwrap(); + let table = schema.table(&table_name).await.unwrap(); for (i, field) in table.schema().fields().iter().enumerate() { builder.add_column( &catalog_name, @@ -168,6 +172,7 @@ impl InformationSchemaConfig { } } +#[async_trait] impl SchemaProvider for InformationSchemaProvider { fn as_any(&self) -> &(dyn Any + 'static) { self @@ -182,7 +187,7 @@ impl SchemaProvider for InformationSchemaProvider { ] } - fn table(&self, name: &str) -> Option> { + async fn table(&self, name: &str) -> Option> { let config = self.config.clone(); let table: Arc = if name.eq_ignore_ascii_case("tables") { Arc::new(InformationSchemaTables::new(config)) @@ -246,7 +251,7 @@ impl PartitionStream for InformationSchemaTables { self.schema.clone(), // TODO: Stream this futures::stream::once(async move { - config.make_tables(&mut builder); + config.make_tables(&mut builder).await; Ok(builder.finish()) }), )) @@ -337,7 +342,7 @@ impl PartitionStream for InformationSchemaViews { self.schema.clone(), // TODO: Stream this futures::stream::once(async move { - config.make_views(&mut builder); + config.make_views(&mut builder).await; Ok(builder.finish()) }), )) @@ -451,7 +456,7 @@ impl PartitionStream for InformationSchemaColumns { self.schema.clone(), // TODO: Stream this futures::stream::once(async move { - config.make_columns(&mut builder); + config.make_columns(&mut builder).await; Ok(builder.finish()) }), )) diff --git a/datafusion/core/src/catalog/listing_schema.rs b/datafusion/core/src/catalog/listing_schema.rs index 265e08f7a6a3..32ee9f62ac3d 100644 --- a/datafusion/core/src/catalog/listing_schema.rs +++ b/datafusion/core/src/catalog/listing_schema.rs @@ -20,6 +20,7 @@ use crate::catalog::schema::SchemaProvider; use crate::datasource::datasource::TableProviderFactory; use crate::datasource::TableProvider; use crate::execution::context::SessionState; +use async_trait::async_trait; use datafusion_common::parsers::CompressionTypeVariant; use datafusion_common::{DFSchema, DataFusionError, OwnedTableReference}; use datafusion_expr::CreateExternalTable; @@ -156,6 +157,7 @@ impl ListingSchemaProvider { } } +#[async_trait] impl SchemaProvider for ListingSchemaProvider { fn as_any(&self) -> &dyn Any { self @@ -170,7 +172,7 @@ impl SchemaProvider for ListingSchemaProvider { .collect() } - fn table(&self, name: &str) -> Option> { + async fn table(&self, name: &str) -> Option> { self.tables .lock() .expect("Can't lock tables") diff --git a/datafusion/core/src/catalog/schema.rs b/datafusion/core/src/catalog/schema.rs index 41187c62965b..9d3b47546e39 100644 --- a/datafusion/core/src/catalog/schema.rs +++ b/datafusion/core/src/catalog/schema.rs @@ -18,6 +18,7 @@ //! Describes the interface and built-in implementations of schemas, //! representing collections of named tables. +use async_trait::async_trait; use dashmap::DashMap; use std::any::Any; use std::sync::Arc; @@ -26,6 +27,7 @@ use crate::datasource::TableProvider; use crate::error::{DataFusionError, Result}; /// Represents a schema, comprising a number of named tables. +#[async_trait] pub trait SchemaProvider: Sync + Send { /// Returns the schema provider as [`Any`](std::any::Any) /// so that it can be downcast to a specific implementation. @@ -35,7 +37,7 @@ pub trait SchemaProvider: Sync + Send { fn table_names(&self) -> Vec; /// Retrieves a specific table from the schema by name, provided it exists. - fn table(&self, name: &str) -> Option>; + async fn table(&self, name: &str) -> Option>; /// If supported by the implementation, adds a new table to this schema. /// If a table of the same name existed before, it returns "Table already exists" error. @@ -85,6 +87,7 @@ impl Default for MemorySchemaProvider { } } +#[async_trait] impl SchemaProvider for MemorySchemaProvider { fn as_any(&self) -> &dyn Any { self @@ -97,7 +100,7 @@ impl SchemaProvider for MemorySchemaProvider { .collect() } - fn table(&self, name: &str) -> Option> { + async fn table(&self, name: &str) -> Option> { self.tables.get(name).map(|table| table.value().clone()) } diff --git a/datafusion/core/src/dataframe.rs b/datafusion/core/src/dataframe.rs index f7241cb966cd..f0542e14985c 100644 --- a/datafusion/core/src/dataframe.rs +++ b/datafusion/core/src/dataframe.rs @@ -917,7 +917,7 @@ mod tests { let ctx = SessionContext::new(); ctx.register_batch("t", batch)?; - let df = ctx.table("t")?.select_columns(&["f.c1"])?; + let df = ctx.table("t").await?.select_columns(&["f.c1"])?; let df_results = df.collect().await?; @@ -1036,7 +1036,7 @@ mod tests { )); // build query with a UDF using DataFrame API - let df = ctx.table("aggregate_test_100")?; + let df = ctx.table("aggregate_test_100").await?; let expr = df.registry().udf("my_fn")?.call(vec![col("c12")]); let df = df.select(vec![expr])?; @@ -1101,7 +1101,7 @@ mod tests { ctx.register_table("test_table", Arc::new(df_impl.clone()))?; // pull the table out - let table = ctx.table("test_table")?; + let table = ctx.table("test_table").await?; let group_expr = vec![col("c1")]; let aggr_expr = vec![sum(col("c12"))]; @@ -1161,7 +1161,7 @@ mod tests { async fn test_table_with_name(name: &str) -> Result { let mut ctx = SessionContext::new(); register_aggregate_csv(&mut ctx, name).await?; - ctx.table(name) + ctx.table(name).await } async fn test_table() -> Result { @@ -1301,8 +1301,15 @@ mod tests { ctx.register_table("t1", table.clone())?; ctx.register_table("t2", table)?; let df = ctx - .table("t1")? - .join(ctx.table("t2")?, JoinType::Inner, &["c1"], &["c1"], None)? + .table("t1") + .await? + .join( + ctx.table("t2").await?, + JoinType::Inner, + &["c1"], + &["c1"], + None, + )? .sort(vec![ // make the test deterministic col("t1.c1").sort(true, true), @@ -1379,10 +1386,11 @@ mod tests { ) .await?; - ctx.register_table("t1", Arc::new(ctx.table("test")?))?; + ctx.register_table("t1", Arc::new(ctx.table("test").await?))?; let df = ctx - .table("t1")? + .table("t1") + .await? .filter(col("id").eq(lit(1)))? .select_columns(&["bool_col", "int_col"])?; @@ -1463,7 +1471,8 @@ mod tests { ctx.register_batch("t", batch)?; let df = ctx - .table("t")? + .table("t") + .await? // try and create a column with a '.' in it .with_column("f.c2", lit("hello"))?; diff --git a/datafusion/core/src/datasource/view.rs b/datafusion/core/src/datasource/view.rs index 4fae03fad717..2d2f33dc2051 100644 --- a/datafusion/core/src/datasource/view.rs +++ b/datafusion/core/src/datasource/view.rs @@ -428,12 +428,13 @@ mod tests { ) .await?; - ctx.register_table("t1", Arc::new(ctx.table("test")?))?; + ctx.register_table("t1", Arc::new(ctx.table("test").await?))?; ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?; let df = ctx - .table("t2")? + .table("t2") + .await? .filter(col("id").eq(lit(1)))? .select_columns(&["bool_col", "int_col"])?; @@ -457,12 +458,13 @@ mod tests { ) .await?; - ctx.register_table("t1", Arc::new(ctx.table("test")?))?; + ctx.register_table("t1", Arc::new(ctx.table("test").await?))?; ctx.sql("CREATE VIEW t2 as SELECT * FROM t1").await?; let df = ctx - .table("t2")? + .table("t2") + .await? .limit(0, Some(10))? .select_columns(&["bool_col", "int_col"])?; diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs index 32cf8d165e2a..ba8accdcee46 100644 --- a/datafusion/core/src/execution/context.rs +++ b/datafusion/core/src/execution/context.rs @@ -30,6 +30,7 @@ use crate::{ pub use datafusion_physical_expr::execution_props::ExecutionProps; use datafusion_physical_expr::var_provider::is_system_variables; use parking_lot::RwLock; +use std::ops::ControlFlow; use std::sync::Arc; use std::{ any::{Any, TypeId}, @@ -94,6 +95,7 @@ use crate::physical_optimizer::optimize_sorts::OptimizeSorts; use crate::physical_optimizer::pipeline_checker::PipelineChecker; use crate::physical_optimizer::pipeline_fixer::PipelineFixer; use datafusion_optimizer::OptimizerConfig; +use datafusion_sql::planner::object_name_to_table_reference; use uuid::Uuid; use super::options::{ @@ -236,23 +238,14 @@ impl SessionContext { /// Creates a [`DataFrame`] that will execute a SQL query. /// - /// This method is `async` because queries of type `CREATE EXTERNAL TABLE` - /// might require the schema to be inferred. + /// Note: This api implements DDL such as `CREATE TABLE` and `CREATE VIEW` with in memory + /// default implementations. + /// + /// If this is not desirable, consider using [`SessionState::create_logical_plan()`] which + /// does not mutate the state based on such statements. pub async fn sql(&self, sql: &str) -> Result { - let mut statements = DFParser::parse_sql(sql)?; - if statements.len() != 1 { - return Err(DataFusionError::NotImplemented( - "The context currently only supports a single SQL statement".to_string(), - )); - } - // create a query planner - let plan = { - // TODO: Move catalog off SessionState onto SessionContext - let state = self.state.read(); - let query_planner = SqlToRel::new(&*state); - query_planner.statement_to_plan(statements.pop_front().unwrap())? - }; + let plan = self.state().create_logical_plan(sql).await?; match plan { LogicalPlan::CreateExternalTable(cmd) => { @@ -266,7 +259,7 @@ impl SessionContext { or_replace, }) => { let input = Arc::try_unwrap(input).unwrap_or_else(|e| e.as_ref().clone()); - let table = self.table(&name); + let table = self.table(&name).await; match (if_not_exists, or_replace, table) { (true, false, Ok(_)) => self.return_empty_dataframe(), @@ -306,7 +299,7 @@ impl SessionContext { or_replace, definition, }) => { - let view = self.table(&name); + let view = self.table(&name).await; match (or_replace, view) { (true, Ok(_)) => { @@ -333,7 +326,7 @@ impl SessionContext { LogicalPlan::DropTable(DropTable { name, if_exists, .. }) => { - let result = self.find_and_deregister(&name, TableType::Base); + let result = self.find_and_deregister(&name, TableType::Base).await; match (result, if_exists) { (Ok(true), _) => self.return_empty_dataframe(), (_, true) => self.return_empty_dataframe(), @@ -346,7 +339,7 @@ impl SessionContext { LogicalPlan::DropView(DropView { name, if_exists, .. }) => { - let result = self.find_and_deregister(&name, TableType::View); + let result = self.find_and_deregister(&name, TableType::View).await; match (result, if_exists) { (Ok(true), _) => self.return_empty_dataframe(), (_, true) => self.return_empty_dataframe(), @@ -456,7 +449,7 @@ impl SessionContext { let table_provider: Arc = self.create_custom_table(cmd).await?; - let table = self.table(&cmd.name); + let table = self.table(&cmd.name).await; match (cmd.if_not_exists, table) { (true, Ok(_)) => self.return_empty_dataframe(), (_, Err(_)) => { @@ -490,43 +483,31 @@ impl SessionContext { Ok(table) } - fn find_and_deregister<'a>( + async fn find_and_deregister<'a>( &self, table_ref: impl Into>, table_type: TableType, ) -> Result { let table_ref = table_ref.into(); - let table_provider = self - .state - .read() - .schema_for_ref(table_ref)? - .table(table_ref.table()); + let maybe_schema = { + let state = self.state.read(); + let resolved = state.resolve_table_ref(table_ref); + state + .catalog_list + .catalog(resolved.catalog) + .and_then(|c| c.schema(resolved.schema)) + }; - if let Some(table_provider) = table_provider { - if table_provider.table_type() == table_type { - self.deregister_table(table_ref)?; - return Ok(true); + if let Some(schema) = maybe_schema { + if let Some(table_provider) = schema.table(table_ref.table()).await { + if table_provider.table_type() == table_type { + schema.deregister_table(table_ref.table())?; + return Ok(true); + } } } - Ok(false) - } - /// Creates a logical plan. - /// - /// This function is intended for internal use and should not be called directly. - #[deprecated(note = "Use SessionContext::sql which snapshots the SessionState")] - pub fn create_logical_plan(&self, sql: &str) -> Result { - let mut statements = DFParser::parse_sql(sql)?; - - if statements.len() != 1 { - return Err(DataFusionError::NotImplemented( - "The context currently only supports a single SQL statement".to_string(), - )); - } - // create a query planner - let state = self.state.read().clone(); - let query_planner = SqlToRel::new(&state); - query_planner.statement_to_plan(statements.pop_front().unwrap()) + Ok(false) } /// Registers a variable provider within this context. @@ -914,12 +895,12 @@ impl SessionContext { /// provided reference. /// /// [`register_table`]: SessionContext::register_table - pub fn table<'a>( + pub async fn table<'a>( &self, table_ref: impl Into>, ) -> Result { let table_ref = table_ref.into(); - let provider = self.table_provider(table_ref)?; + let provider = self.table_provider(table_ref).await?; let plan = LogicalPlanBuilder::scan( table_ref.table(), provider_as_source(Arc::clone(&provider)), @@ -929,14 +910,14 @@ impl SessionContext { Ok(DataFrame::new(self.state(), plan)) } - /// Return a [`TabelProvider`] for the specified table. - pub fn table_provider<'a>( + /// Return a [`TableProvider`] for the specified table. + pub async fn table_provider<'a>( &self, table_ref: impl Into>, ) -> Result> { let table_ref = table_ref.into(); let schema = self.state.read().schema_for_ref(table_ref)?; - match schema.table(table_ref.table()) { + match schema.table(table_ref.table()).await { Some(ref provider) => Ok(Arc::clone(provider)), _ => Err(DataFusionError::Plan(format!( "No table named '{}'", @@ -1640,6 +1621,99 @@ impl SessionState { self } + /// Creates a [`LogicalPlan`] from the provided SQL string + /// + /// See [`SessionContext::sql`] for a higher-level interface that also handles DDL + pub async fn create_logical_plan(&self, sql: &str) -> Result { + use crate::catalog::information_schema::INFORMATION_SCHEMA_TABLES; + use datafusion_sql::parser::Statement as DFStatement; + use sqlparser::ast::*; + use std::collections::hash_map::Entry; + + let mut statements = DFParser::parse_sql(sql)?; + if statements.len() != 1 { + return Err(DataFusionError::NotImplemented( + "The context currently only supports a single SQL statement".to_string(), + )); + } + let statement = statements.pop_front().unwrap(); + + // Getting `TableProviders` is async but planing is not -- thus pre-fetch + // table providers for all relations referenced in this query + let mut relations = hashbrown::HashSet::with_capacity(10); + + match &statement { + DFStatement::Statement(s) => { + struct RelationVisitor<'a>(&'a mut hashbrown::HashSet); + + impl<'a> Visitor for RelationVisitor<'a> { + type Break = (); + + fn pre_visit_relation( + &mut self, + relation: &ObjectName, + ) -> ControlFlow<()> { + self.0.get_or_insert_with(relation, |_| relation.clone()); + ControlFlow::Continue(()) + } + + fn pre_visit_statement( + &mut self, + statement: &Statement, + ) -> ControlFlow<()> { + if let Statement::ShowCreate { + obj_type: ShowCreateObject::Table | ShowCreateObject::View, + obj_name, + } = statement + { + self.0.get_or_insert_with(obj_name, |_| obj_name.clone()); + } + ControlFlow::Continue(()) + } + } + let mut visitor = RelationVisitor(&mut relations); + let _ = s.as_ref().visit(&mut visitor); + } + DFStatement::CreateExternalTable(table) => { + relations.insert(ObjectName(vec![Ident::from(table.name.as_str())])); + } + DFStatement::DescribeTable(table) => { + relations + .get_or_insert_with(&table.table_name, |_| table.table_name.clone()); + } + } + + // Always include information_schema if available + if self.config.information_schema() { + for s in INFORMATION_SCHEMA_TABLES { + relations.insert(ObjectName(vec![ + Ident::new(INFORMATION_SCHEMA), + Ident::new(*s), + ])); + } + } + + let mut provider = SessionContextProvider { + state: self, + tables: HashMap::with_capacity(relations.len()), + }; + + for relation in relations { + let reference = object_name_to_table_reference(relation)?; + let resolved = self.resolve_table_ref(reference.as_table_reference()); + if let Entry::Vacant(v) = provider.tables.entry(resolved.to_string()) { + if let Ok(schema) = self.schema_for_ref(resolved) { + if let Some(table) = schema.table(resolved.table).await { + v.insert(provider_as_source(table)); + } + } + } + } + + let query = SqlToRel::new(&provider); + query.statement_to_plan(statement) + } + /// Optimizes the logical plan by applying optimizer rules. pub fn optimize(&self, plan: &LogicalPlan) -> Result { if let LogicalPlan::Explain(e) = plan { @@ -1668,6 +1742,8 @@ impl SessionState { } /// Creates a physical plan from a logical plan. + /// + /// Note: this first calls [`Self::optimize`] on the provided plan pub async fn create_physical_plan( &self, logical_plan: &LogicalPlan, @@ -1714,29 +1790,26 @@ impl SessionState { } } -impl ContextProvider for SessionState { +struct SessionContextProvider<'a> { + state: &'a SessionState, + tables: HashMap>, +} + +impl<'a> ContextProvider for SessionContextProvider<'a> { fn get_table_provider(&self, name: TableReference) -> Result> { - let resolved_ref = self.resolve_table_ref(name); - match self.schema_for_ref(resolved_ref) { - Ok(schema) => { - let provider = schema.table(resolved_ref.table).ok_or_else(|| { - DataFusionError::Plan(format!( - "table '{}.{}.{}' not found", - resolved_ref.catalog, resolved_ref.schema, resolved_ref.table - )) - })?; - Ok(provider_as_source(provider)) - } - Err(e) => Err(e), - } + let name = self.state.resolve_table_ref(name).to_string(); + self.tables + .get(&name) + .cloned() + .ok_or_else(|| DataFusionError::Plan(format!("table '{name}' not found"))) } fn get_function_meta(&self, name: &str) -> Option> { - self.scalar_functions.get(name).cloned() + self.state.scalar_functions.get(name).cloned() } fn get_aggregate_meta(&self, name: &str) -> Option> { - self.aggregate_functions.get(name).cloned() + self.state.aggregate_functions.get(name).cloned() } fn get_variable_type(&self, variable_names: &[String]) -> Option { @@ -1750,14 +1823,15 @@ impl ContextProvider for SessionState { VarType::UserDefined }; - self.execution_props + self.state + .execution_props .var_providers .as_ref() .and_then(|provider| provider.get(&provider_type)?.get_type(variable_names)) } fn options(&self) -> &ConfigOptions { - self.config_options() + self.state.config_options() } } diff --git a/datafusion/core/tests/dataframe.rs b/datafusion/core/tests/dataframe.rs index 190248efe847..e4d9d1c4e91b 100644 --- a/datafusion/core/tests/dataframe.rs +++ b/datafusion/core/tests/dataframe.rs @@ -64,11 +64,11 @@ async fn join() -> Result<()> { ctx.register_batch("aa", batch1)?; - let df1 = ctx.table("aa")?; + let df1 = ctx.table("aa").await?; ctx.register_batch("aaa", batch2)?; - let df2 = ctx.table("aaa")?; + let df2 = ctx.table("aaa").await?; let a = df1.join(df2, JoinType::Inner, &["a"], &["a"], None)?; @@ -100,6 +100,7 @@ async fn sort_on_unprojected_columns() -> Result<()> { let df = ctx .table("t") + .await .unwrap() .select(vec![col("a")]) .unwrap() @@ -138,6 +139,7 @@ async fn filter_with_alias_overwrite() -> Result<()> { let df = ctx .table("t") + .await .unwrap() .select(vec![(col("a").eq(lit(10))).alias("a")]) .unwrap() @@ -174,6 +176,7 @@ async fn select_with_alias_overwrite() -> Result<()> { let df = ctx .table("t") + .await .unwrap() .select(vec![col("a").alias("a")]) .unwrap() @@ -208,7 +211,8 @@ async fn test_grouping_sets() -> Result<()> { vec![col("a"), col("b")], ])); - let df = create_test_table()? + let df = create_test_table() + .await? .aggregate(vec![grouping_set_expr], vec![count(col("a"))])? .sort(vec![ Expr::Sort(Sort::new(Box::new(col("a")), false, true)), @@ -354,8 +358,8 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> { #[tokio::test] async fn join_with_alias_filter() -> Result<()> { let join_ctx = create_join_context()?; - let t1 = join_ctx.table("t1")?; - let t2 = join_ctx.table("t2")?; + let t1 = join_ctx.table("t1").await?; + let t2 = join_ctx.table("t2").await?; let t1_schema = t1.schema().clone(); let t2_schema = t2.schema().clone(); @@ -407,7 +411,7 @@ async fn join_with_alias_filter() -> Result<()> { Ok(()) } -fn create_test_table() -> Result { +async fn create_test_table() -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), @@ -431,7 +435,7 @@ fn create_test_table() -> Result { ctx.register_batch("test", batch)?; - ctx.table("test") + ctx.table("test").await } async fn aggregates_table(ctx: &SessionContext) -> Result { diff --git a/datafusion/core/tests/dataframe_functions.rs b/datafusion/core/tests/dataframe_functions.rs index 624291a952df..c6291dc36c46 100644 --- a/datafusion/core/tests/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe_functions.rs @@ -34,7 +34,7 @@ use datafusion::execution::context::SessionContext; use datafusion::assert_batches_eq; use datafusion_expr::{approx_median, cast}; -fn create_test_table() -> Result { +async fn create_test_table() -> Result { let schema = Arc::new(Schema::new(vec![ Field::new("a", DataType::Utf8, false), Field::new("b", DataType::Int32, false), @@ -58,7 +58,7 @@ fn create_test_table() -> Result { ctx.register_batch("test", batch)?; - ctx.table("test") + ctx.table("test").await } /// Excutes an expression on the test dataframe as a select. @@ -69,7 +69,7 @@ macro_rules! assert_fn_batches { assert_fn_batches!($EXPR, $EXPECTED, 10) }; ($EXPR:expr, $EXPECTED: expr, $LIMIT: expr) => { - let df = create_test_table()?; + let df = create_test_table().await?; let df = df.select(vec![$EXPR])?.limit(0, Some($LIMIT))?; let batches = df.collect().await?; @@ -162,7 +162,7 @@ async fn test_fn_approx_median() -> Result<()> { "+----------------------+", ]; - let df = create_test_table()?; + let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; assert_batches_eq!(expected, &batches); @@ -182,7 +182,7 @@ async fn test_fn_approx_percentile_cont() -> Result<()> { "+-------------------------------------------+", ]; - let df = create_test_table()?; + let df = create_test_table().await?; let batches = df.aggregate(vec![], vec![expr]).unwrap().collect().await?; assert_batches_eq!(expected, &batches); diff --git a/datafusion/core/tests/sql/create_drop.rs b/datafusion/core/tests/sql/create_drop.rs index 1d32d5e0d93c..1e3331db6b31 100644 --- a/datafusion/core/tests/sql/create_drop.rs +++ b/datafusion/core/tests/sql/create_drop.rs @@ -96,7 +96,7 @@ async fn create_external_table_with_ddl() -> Result<()> { let exists = schema.table_exist("dt"); assert!(exists, "Table should have been created!"); - let table_schema = schema.table("dt").unwrap().schema(); + let table_schema = schema.table("dt").await.unwrap().schema(); assert_eq!(3, table_schema.fields().len()); diff --git a/datafusion/core/tests/sql/errors.rs b/datafusion/core/tests/sql/errors.rs index ed15631a681d..f3a96c092b40 100644 --- a/datafusion/core/tests/sql/errors.rs +++ b/datafusion/core/tests/sql/errors.rs @@ -134,14 +134,15 @@ async fn invalid_qualified_table_references() -> Result<()> { } #[tokio::test] -#[allow(deprecated)] // TODO: Remove this test once create_logical_plan removed async fn unsupported_sql_returns_error() -> Result<()> { let ctx = SessionContext::new(); register_aggregate_csv(&ctx).await?; + let state = ctx.state(); + // create view let sql = "create view test_view as select * from aggregate_test_100"; - let plan = ctx.create_logical_plan(sql); - let physical_plan = ctx.create_physical_plan(&plan.unwrap()).await; + let plan = state.create_logical_plan(sql).await; + let physical_plan = state.create_physical_plan(&plan.unwrap()).await; assert!(physical_plan.is_err()); assert_eq!( format!("{}", physical_plan.unwrap_err()), @@ -150,8 +151,8 @@ async fn unsupported_sql_returns_error() -> Result<()> { ); // // drop view let sql = "drop view test_view"; - let plan = ctx.create_logical_plan(sql); - let physical_plan = ctx.create_physical_plan(&plan.unwrap()).await; + let plan = state.create_logical_plan(sql).await; + let physical_plan = state.create_physical_plan(&plan.unwrap()).await; assert!(physical_plan.is_err()); assert_eq!( format!("{}", physical_plan.unwrap_err()), @@ -160,8 +161,8 @@ async fn unsupported_sql_returns_error() -> Result<()> { ); // // drop table let sql = "drop table aggregate_test_100"; - let plan = ctx.create_logical_plan(sql); - let physical_plan = ctx.create_physical_plan(&plan.unwrap()).await; + let plan = state.create_logical_plan(sql).await; + let physical_plan = state.create_physical_plan(&plan.unwrap()).await; assert!(physical_plan.is_err()); assert_eq!( format!("{}", physical_plan.unwrap_err()), diff --git a/datafusion/core/tests/sql/information_schema.rs b/datafusion/core/tests/sql/information_schema.rs index 28b434f025da..9cd5fdb4e259 100644 --- a/datafusion/core/tests/sql/information_schema.rs +++ b/datafusion/core/tests/sql/information_schema.rs @@ -411,8 +411,7 @@ async fn information_schema_columns_not_exist_by_default() { .unwrap_err(); assert_eq!( err.to_string(), - // Error propagates from SessionState::schema_for_ref - "Error during planning: failed to resolve schema: information_schema" + "Error during planning: table 'datafusion.information_schema.columns' not found" ); } diff --git a/datafusion/core/tests/sql/projection.rs b/datafusion/core/tests/sql/projection.rs index 8bb33a91d458..b4627c5979bf 100644 --- a/datafusion/core/tests/sql/projection.rs +++ b/datafusion/core/tests/sql/projection.rs @@ -167,7 +167,7 @@ async fn projection_on_table_scan() -> Result<()> { let partition_count = 4; let ctx = partitioned_csv::create_ctx(&tmp_dir, partition_count).await?; - let table = ctx.table("test")?; + let table = ctx.table("test").await?; let logical_plan = LogicalPlanBuilder::from(table.into_optimized_plan()?) .project(vec![col("c2")])? .build()?; @@ -208,7 +208,7 @@ async fn preserve_nullability_on_projection() -> Result<()> { let tmp_dir = TempDir::new()?; let ctx = partitioned_csv::create_ctx(&tmp_dir, 1).await?; - let schema: Schema = ctx.table("test").unwrap().schema().clone().into(); + let schema: Schema = ctx.table("test").await.unwrap().schema().clone().into(); assert!(!schema.field_with_name("c1")?.is_nullable()); let plan = scan_empty(None, &schema, None)? diff --git a/datafusion/core/tests/sql/udf.rs b/datafusion/core/tests/sql/udf.rs index 9ddc5f6141c5..0688aa319488 100644 --- a/datafusion/core/tests/sql/udf.rs +++ b/datafusion/core/tests/sql/udf.rs @@ -75,7 +75,7 @@ async fn scalar_udf() -> Result<()> { // from here on, we may be in a different scope. We would still like to be able // to call UDFs. - let t = ctx.table("t")?; + let t = ctx.table("t").await?; let plan = LogicalPlanBuilder::from(t.into_optimized_plan()?) .project(vec![ diff --git a/datafusion/core/tests/sqllogictests/src/insert/mod.rs b/datafusion/core/tests/sqllogictests/src/insert/mod.rs index 3412e4ad8db4..17d673302cbb 100644 --- a/datafusion/core/tests/sqllogictests/src/insert/mod.rs +++ b/datafusion/core/tests/sqllogictests/src/insert/mod.rs @@ -54,8 +54,8 @@ pub async fn insert(ctx: &SessionContext, insert_stmt: SQLStatement) -> Result Result<()> { ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await ?; - let plan = ctx.table("t1")?.to_logical_plan()?; + let plan = ctx.table("t1").await?.to_logical_plan()?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip)); @@ -84,7 +84,7 @@ async fn main() -> Result<()> { ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await ?; - let logical_plan = ctx.table("t1")?.to_logical_plan()?; + let logical_plan = ctx.table("t1").await?.to_logical_plan()?; let physical_plan = ctx.create_physical_plan(&logical_plan).await?; let bytes = physical_plan_to_bytes(physical_plan.clone())?; let physical_round_trip = physical_plan_from_bytes(&bytes, &ctx)?; diff --git a/datafusion/proto/examples/logical_plan_serde.rs b/datafusion/proto/examples/logical_plan_serde.rs index 0f8312372778..9f468638c150 100644 --- a/datafusion/proto/examples/logical_plan_serde.rs +++ b/datafusion/proto/examples/logical_plan_serde.rs @@ -24,7 +24,7 @@ async fn main() -> Result<()> { let ctx = SessionContext::new(); ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await?; - let plan = ctx.table("t1")?.into_optimized_plan()?; + let plan = ctx.table("t1").await?.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); diff --git a/datafusion/proto/examples/physical_plan_serde.rs b/datafusion/proto/examples/physical_plan_serde.rs index 803b9e3186a4..72e216074a16 100644 --- a/datafusion/proto/examples/physical_plan_serde.rs +++ b/datafusion/proto/examples/physical_plan_serde.rs @@ -24,7 +24,7 @@ async fn main() -> Result<()> { let ctx = SessionContext::new(); ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await?; - let dataframe = ctx.table("t1")?; + let dataframe = ctx.table("t1").await?; let physical_plan = dataframe.create_physical_plan().await?; let bytes = physical_plan_to_bytes(physical_plan.clone())?; let physical_round_trip = physical_plan_from_bytes(&bytes, &ctx)?; diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index 2a5626d510ba..ff94d4670091 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1429,7 +1429,7 @@ mod roundtrip_tests { let ctx = SessionContext::new(); ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await?; - let scan = ctx.table("t1")?.into_optimized_plan()?; + let scan = ctx.table("t1").await?.into_optimized_plan()?; let topk_plan = LogicalPlan::Extension(Extension { node: Arc::new(TopKPlanNode::new(3, scan, col("revenue"))), }); @@ -1523,7 +1523,7 @@ mod roundtrip_tests { ctx.sql(sql).await.unwrap(); let codec = TestTableProviderCodec {}; - let scan = ctx.table("t")?.into_optimized_plan()?; + let scan = ctx.table("t").await?.into_optimized_plan()?; let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; let logical_round_trip = logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; @@ -1589,7 +1589,7 @@ mod roundtrip_tests { let ctx = SessionContext::new(); ctx.register_csv("t1", "testdata/test.csv", CsvReadOptions::default()) .await?; - let plan = ctx.table("t1")?.into_optimized_plan()?; + let plan = ctx.table("t1").await?.into_optimized_plan()?; let bytes = logical_plan_to_bytes(&plan)?; let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));